diff --git a/.bazelrc b/.bazelrc index ceba7bfdbac74d1e44aadc3010e5e84bd36ce3ee..d5432aeb1757677a3e53bbb76ada2148ca5c23c6 100644 --- a/.bazelrc +++ b/.bazelrc @@ -25,12 +25,14 @@ build --define framework_shared_object=true # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true +build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl -c opt # This config option is used to enable MKL-DNN open source library only, # without depending on MKL binary version. build:mkl_open_source_only --define=build_with_mkl_dnn_only=true build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0 build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true @@ -93,9 +95,6 @@ build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include -# Disable MKL-DNN contraction kernels by default. -build --define=tensorflow_mkldnn_contraction_kernel=0 - # Default options should come above this line # Options from ./configure diff --git a/RELEASE.md b/RELEASE.md index 282430d12303bde980e19e3c3602eb91b1a54d63..0a56e6909870e398c9d6349576cd2f8e6734f072 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -849,7 +849,7 @@ answered questions, and were part of inspiring discussions. * Remove `tf.contrib.data.Iterator.from_dataset()` method. Use `Dataset.make_initializable_iterator()` instead. * Remove seldom used and unnecessary `tf.contrib.data.Iterator.dispose_op()`. -* Reorder some TFGAN loss functions in a non-backwards compatible way. +* Reorder some TF-GAN loss functions in a non-backwards compatible way. ## Known Issues * In Python 3, `Dataset.from_generator()` does not support Unicode strings. diff --git a/WORKSPACE b/WORKSPACE index 2277e83a3f67b62cf4ee1311767ee06c0549c697..1c59686f16c9eb25bf509b216d18628250157319 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file" http_archive( name = "io_bazel_rules_closure", - sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", - strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", + sha256 = "43c9b882fa921923bcba764453f4058d102bece35a37c9f6383c713004aacff1", + strip_prefix = "rules_closure-9889e2348259a5aad7e805547c1a0cf311cfcd91", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", # 2018-12-21 ], ) diff --git a/configure.py b/configure.py index c588381d40cce95df030532c4131616e4bf16c08..bb59063f7932dd1f74e12a39db029a86268a7c87 100644 --- a/configure.py +++ b/configure.py @@ -33,7 +33,7 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDA_VERSION = '10.0' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f07e7365d3482cde5b7bb76ebf22890150e98651..29d71c323ab5ee860ebf48c332cfd7f607f3f0c3 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -370,13 +370,21 @@ config_setting( define_values = {"tf_api_version": "2"}, ) +# This flag is defined for select statements that match both +# on 'windows' and 'api_version_2'. In this case, bazel requires +# having a flag which is a superset of these two. +config_setting( + name = "windows_and_api_version_2", + define_values = {"tf_api_version": "2"}, + values = {"cpu": "x64_windows"}, +) + package_group( name = "internal", packages = [ "-//third_party/tensorflow/python/estimator", "//learning/deepmind/...", "//learning/meta_rank/...", - "//learning/pathways/...", # While dataset C++ api requires internals "//tensorflow/...", "//tensorflow_estimator/contrib/...", "//tensorflow_fold/llgtm/...", diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 2c0a7452692e5cdb184f7f0a77eb1b646a1772d4..a93799bfe84b0f9c4743e1ad0effd6e69ad7f3f2 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -52,7 +52,7 @@ elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) # Enable TF2 behaviors -from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top +from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 514aba1b59631f882523396aab0f4d3d5e88a893..b293d0f1534d2017efa9bfed655fa6eadc5d88de 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -62,7 +62,8 @@ if '__all__' in vars(): vars()['__all__'].append('contrib') from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top -app.flags = flags # pylint: disable=undefined-variable +from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top +app.flags = flags # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 9580215a317b1a6b1cdacbd430a1764af61be990..9f2f83920cc73028fd2372afaf303e8b1c1c64f9 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2881,6 +2881,9 @@ const char* TF_ServerTarget(TF_Server* server) { #endif } -void TF_DeleteServer(TF_Server* server) { delete server; } - +void TF_DeleteServer(TF_Server* server) { +#ifndef __ANDROID__ + delete server; +#endif +} } // end extern "C" diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 9f09ad1fc307ed96f327df289c0a45bf79325d7e..3ea1724d1e6ac4135ef6883e2bba3e56ee401a41 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,11 +3,10 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", - "tf_cuda_cc_test", - "tf_cc_test", "tf_copts", - "tfe_xla_copts", + "tf_cuda_cc_test", "tf_cuda_library", + "tfe_xla_copts", ) tf_cuda_library( @@ -29,6 +28,7 @@ tf_cuda_library( "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/common_runtime/eager:profiler", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", @@ -89,6 +89,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:profiler", "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", @@ -204,6 +205,31 @@ tf_cuda_library( ], ) +tf_cuda_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = [ + "c_api_experimental_test.cc", + ], + extra_copts = tfe_xla_copts(), + tags = [ + "guitar", + "multi_gpu", + ], + deps = [ + ":c_api_experimental", + ":c_api_test_util", + "//tensorflow/c:c_test_util", + "//tensorflow/cc/profiler", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "tape", hdrs = ["tape.h"], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 027d752f420238da867cb9d8c116640e1730caaa..d5a391a98d2bdc6f80858c6673aa2ce2fac6f49a 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -774,7 +774,7 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, if (!status->status.ok()) return; tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.RunMetadataProto()->Clear(); + ctx->context.ClearRunMetadata(); } namespace { diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 52b0824552855860dfb138f3ac9a5d3afa7dc965..ffcd5ace0b98597363abe63201bf6c328a03212f 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -83,7 +83,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( } } - if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (padded_shape.IsTuple()) { if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { // Currently, the only case of XlaTensor containing a tuple shape is to // represent 64 bit ints, doubles, and complex numbers (we don't support @@ -99,7 +99,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); const xla::Shape& shape1 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); - if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + if (shape0.IsTuple() || shape1.IsTuple()) { status->status = tensorflow::errors::InvalidArgument( "XlaTensors should not contain nested tuples. Shape: ", padded_shape.DebugString()); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 3461d81b93594021b4e977aa60ef5d1f92e1fc5c..1ce03fb22693960627c27cd4aec58106a9ff3218 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -18,6 +18,29 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +using tensorflow::string; + void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { op->operation.ConsumeInput(h->handle); } + +TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx) { + return new TFE_Profiler(ctx); +} + +void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; } + +void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, + TF_Buffer* buf, TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; + string content; + status->status = profiler->profiler->SerializeToString(&content); + void* data = tensorflow::port::Malloc(content.length()); + content.copy(static_cast(data), content.length(), 0); + buf->data = data; + buf->length = content.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 4ee6c066eef7f60fd2f50a197c4dc20ed9c9f927..9eb80f521624e0116dd8ea5e4dbbf7e3d350a09c 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -25,6 +25,24 @@ extern "C" { TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); +// A profiler which will start profiling when creating the object and will stop +// when the object is destroyed. It will profile all operations run under the +// given TFE_Context. Multiple instance of it can be created, but at most one +// of them will profile for each TFE_Context. +// Thread-safety: TFE_Profiler is thread-safe. +typedef struct TFE_Profiler TFE_Profiler; + +TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx); +TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler); + +// The output string is a binary string of tensorflow.tfprof.ProfileProto. +// User can write the string to file for offline analysis by tfprof command-line +// tools or graphical user interface. +TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx, + TFE_Profiler* profiler, + TF_Buffer* buf, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b4fca9d45af2d72e1c21e711b60636ab3f39714 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental.h" + +#include +#include "absl/strings/match.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/profiler/tfprof_log.pb.h" + +using tensorflow::string; + +namespace { + +void ExecuteWithProfiling(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + TFE_Profiler* profiler = TFE_NewProfiler(ctx); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Buffer* profiler_result = TF_NewBuffer(); + TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status); + TFE_DeleteProfiler(profiler); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + tensorflow::tfprof::ProfileProto profile_proto; + EXPECT_TRUE(profile_proto.ParseFromString( + {reinterpret_cast(profiler_result->data), + profiler_result->length})); + TF_DeleteBuffer(profiler_result); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} +TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); } + +} // namespace diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82..7b1035f631ee4ca5283da93696d619dcf7b6deba 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/profiler.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" @@ -100,6 +101,13 @@ struct TFE_Op { tensorflow::EagerOperation operation; }; +struct TFE_Profiler { + TFE_Profiler(TFE_Context* ctx) + : profiler(tensorflow::EagerProfiler::Create(&ctx->context)) {} + + std::unique_ptr profiler; +}; + namespace tensorflow { // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 2a4eaecb6cf2740a522b1e849d1306ebde6c4577..673aed558ab2588d2dd1e463c836082d27ef0777 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -158,3 +158,12 @@ void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, cc_ctx->set_output(i, cc_tensor); } } + +TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return static_cast(cc_ctx->expected_output_dtype(i)); +} + +int64_t TF_StepId(TF_OpKernelContext* ctx) { + return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id(); +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index cefc30bcdf89bdc14a4406299cc29f74153e77ac..721a4aca0be1a23a6764d476578ab4a26382570d 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -111,6 +111,14 @@ TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, TF_Status* status); +// Returns the expected output data type of the ith output. If i < 0 or +// i >= TF_NumOutputs(ctx), the program aborts. +TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( + TF_OpKernelContext* ctx, int i); + +// Returns the step ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index e659ee3c3d258a626ccf03a782ec031b5a703a48..fdeb04c84c99912a9a1f898a7e91391901a6af7d 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -41,6 +41,9 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { struct MyCustomKernel* s = static_cast(kernel); s->compute_called = true; + if (ctx != nullptr) { + EXPECT_EQ(43, TF_StepId(ctx)); + } } static void MyDeleteFunc(void* kernel) { @@ -155,6 +158,8 @@ TEST(TestKernel, TestInputAndOutputCount) { TF_SetOutput(ctx, 24, input, s); EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + EXPECT_EQ(TF_UINT8, TF_ExpectedOutputDataType(ctx, 0)); + TF_DeleteStatus(s); if (input != nullptr) { TF_DeleteTensor(input); @@ -175,6 +180,7 @@ TEST(TestKernel, TestInputAndOutputCount) { OpKernelContext::Params p; DummyDevice dummy_device(nullptr, false); p.device = &dummy_device; + p.step_id = 43; Tensor t(tensorflow::uint8(123)); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index b9a87ba296abfc6b9d9aaeff3b3e26678e4e1b94..ba9fa93654f62b9636eafc53865c7b572307e695 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -634,6 +634,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 8617beec004d0fe912155f054442c5b6249bb6b5..1f8ec09e19c01d0a8b2a3761135ed53dfb2ad3b0 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -32,6 +34,8 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -513,6 +517,18 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get()); if (!s.ok()) return s; + // Create FunctionLibraryRuntime. + SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + OptimizerOptions opts; + auto device_mgr = absl::make_unique(std::move(devices)); + auto pflr = absl::make_unique( + device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions( "_encapsulate", /*outside_compilation_attribute=*/"", *graph, @@ -538,7 +554,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, std::map{}}); } s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, - graph_out.get(), lib_def.get()); + graph_out.get(), flr, lib_def.get()); if (!s.ok()) return s; GraphDef graphdef_out; @@ -941,7 +957,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1101,7 +1119,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", @@ -1112,7 +1132,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"}, @@ -1244,7 +1266,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, @@ -1269,7 +1293,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}}); @@ -1397,7 +1423,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1419,7 +1447,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"i_0_retval_retval", "I:o:0"}}); @@ -1527,7 +1557,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1615,7 +1647,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1716,7 +1750,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"f_0_retval_retval", "F:o:0"}}); @@ -1821,7 +1857,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"f_0_retval_retval", "F:o:0"}}); @@ -1949,7 +1987,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, @@ -1959,7 +1999,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2082,7 +2124,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2092,7 +2136,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2214,7 +2260,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2224,7 +2272,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", @@ -2235,7 +2285,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"key", "host_compute_channel_F1_O3"}, {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O3"}}, + {"_outside_compilation_subgraph", "O3"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2354,7 +2406,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"f_0_retval_retval", "F:o:0"}}); @@ -2465,7 +2519,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 8b01768c49422b331b52a8ba31bade000c95722e..2a770c527b2fae91352fd17dacb13495a3a73f34 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { @@ -308,6 +309,10 @@ xla::StatusOr BuildXlaHostComputeNodeDef( host_compute_builder.Attr("tpu_core", core); } + // Set input tokens. + host_compute_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + // Populate inputs. std::vector input_dtypes; TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes)); @@ -398,8 +403,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( } // Resets "device_ordinal" attr to placeholder value for related nodes -// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If nodes containing -// XlaRecvAtHost/XlaSendFromHost). +// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes +// containing XlaRecvAtHost/XlaSendFromHost). Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { AttrValue device_ordinal_value; device_ordinal_value.set_placeholder("device_ordinal"); @@ -429,6 +434,10 @@ Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->ClearAttr(attr_name); n->AddAttr(attr_name, branch_func); } + } else if (HasNodeAttr(n->def(), "device_ordinal")) { + // Function call node containing outside compilation. + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); } else { return errors::Internal("Unknown node marked with ", kXlaHasHostTransferAttrName, ": ", @@ -1217,20 +1226,129 @@ Status BuildHostGraphForWhileNode( return Status::OK(); } +// Builds host graph for func call nodes. +Status BuildHostGraphForFuncCallNode(const string& func_call_node_name, + const string& xla_cluster_name, + const string& func_call_host_func_name, + const string& host_graph_func_name, + FunctionLibraryDefinition* fld) { + Graph host_graph(fld); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg + // node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, func_call_host_func_name, fld)); + + // Step 3: build a function call node with `host_func_name`, with + // `key_placeholder` as input. + NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name), + func_call_host_func_name, fld); + call_builder.Input(key_placeholder->name(), 0, DT_STRING); + call_builder.Attr("device_ordinal", device_ordinal_value); + call_builder.Attr(kXlaHasHostTransferAttrName, true); + NodeDef call_def; + TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def)); + Status s; + Node* call_node = host_graph.AddNode(call_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, call_node, 0); + + // Convert `host_graph` to function, and add a "device_ordinal" attr. + FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( Graph* g, const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, std::vector* host_graphs, std::vector* shape_inference_graphs, bool* has_outside_compilation) { - std::vector if_nodes, while_nodes; + std::vector if_nodes, while_nodes, func_call_nodes; for (Node* n : g->nodes()) { if (n->type_string() == "If") { if_nodes.push_back(n); } else if (n->type_string() == "While") { while_nodes.push_back(n); + } else if (fld->Contains(n->type_string())) { + func_call_nodes.push_back(n); + } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) { + // Only gradient for user-defined function should be considered as + // function call node. + NameAttrList original_func; + TF_RETURN_IF_ERROR(GetNodeAttr( + n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func)); + if (fld->Contains(original_func.name())) { + func_call_nodes.push_back(n); + } + } + } + + for (Node* n : func_call_nodes) { + // Extract outside compilation for the function call. + bool func_has_outside_compilation = false; + NameAttrList func; + func.set_name(n->type_string()); + typedef protobuf::Map AttrMap; + *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); + string new_func_name = absl::StrCat(n->name(), "_oc"); + string host_func_name = absl::StrCat("oc_func_call_host_", n->name()); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func, new_func_name, host_func_name, host_compute_core, flr, fld, + shape_inference_graphs, &func_has_outside_compilation)); + + // If the function call does not have outside compilation, nothing to do. + if (!func_has_outside_compilation) { + continue; } + + *has_outside_compilation = true; + + // Change `n` to call the new function directly. + NodeDefBuilder replace_builder(n->name(), new_func_name, fld); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + replace_builder.Input(e->src()->name(), e->src_output(), + e->src()->output_type(e->src_output())); + } + for (const auto& attr : n->attrs()) { + replace_builder.Attr(attr.first, attr.second); + } + NodeDef replace_def; + TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def)); + TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def)); + replace->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the function call. + string oc_host_graph_name = + absl::StrCat("oc_func_host_graph_", replace->name()); + TF_RETURN_IF_ERROR( + BuildHostGraphForFuncCallNode(replace->name(), xla_cluster_name, + host_func_name, oc_host_graph_name, fld)); + + // Record the host graph. + host_graphs->push_back(oc_host_graph_name); } for (Node* n : if_nodes) { @@ -1251,12 +1369,12 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, then_branch, then_branch_xla_func_name, then_branch_host_func_name, - host_compute_core, fld, shape_inference_graphs, + host_compute_core, flr, fld, shape_inference_graphs, &then_branch_has_outside_compilation)); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, else_branch, else_branch_xla_func_name, else_branch_host_func_name, - host_compute_core, fld, shape_inference_graphs, + host_compute_core, flr, fld, shape_inference_graphs, &else_branch_has_outside_compilation)); // If then/else branch do not have outside compilation, nothing to do. @@ -1316,12 +1434,12 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( body_xla_func_name = absl::StrCat(body.name(), "_oc"); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - cond, cond_xla_func_name, cond_host_func_name, host_compute_core, fld, - shape_inference_graphs, &cond_has_outside_compilation)); + cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &cond_has_outside_compilation)); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - body, body_xla_func_name, body_host_func_name, host_compute_core, fld, - shape_inference_graphs, &body_has_outside_compilation)); + body, body_xla_func_name, body_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &body_has_outside_compilation)); // If cond/body do not have outside compilation, nothing to do. if (!cond_has_outside_compilation && !body_has_outside_compilation) { @@ -1469,17 +1587,27 @@ Status ExtractOutsideCompilationForFunction( const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, const string& host_graph_func_name, - const std::map& host_compute_core, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation) { + // Convert the function to graph. const string& func_name = func_name_attrs.name(); - const FunctionDef* fdef = fld->Find(func_name); - if (!fdef) { - return errors::Internal("Cannot find function ", func_name); - } + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR( + flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* fbody = flr->GetFunctionBody(handle); + + // Check if we have outside compilation nodes. *has_outside_compilation = false; - for (auto& node_def : fdef->node_def()) { - if (HasNodeAttr(node_def, outside_compilation_attr_name)) { + for (Node* n : fbody->graph->nodes()) { + if (HasNodeAttr(n->def(), outside_compilation_attr_name)) { *has_outside_compilation = true; break; } @@ -1487,16 +1615,6 @@ Status ExtractOutsideCompilationForFunction( // We cannot early return here, because we might have outside compilation in // If/While function body. - // Convert the function to graph. - FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(func_name), AttrSlice(&func_name_attrs.attr()), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - // Preprocess edges between different outside compilations. They will be // restored in `ConstructHostGraph()`. TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( @@ -1553,16 +1671,11 @@ Status ExtractOutsideCompilationForFunction( TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( graph_out.get(), n, host_compute_core)); } - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("extract_outside_compilation_for_func_after_", func_name), - *graph_out, fld); - } // Handle nodes with associated functions. TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, - xla_cluster_name, host_compute_core, fld, + xla_cluster_name, host_compute_core, flr, fld, &outside_compilation_host_graphs, shape_inference_graphs, has_outside_compilation)); @@ -1580,20 +1693,31 @@ Status ExtractOutsideCompilationForFunction( FunctionDef updated_fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef)); + const FunctionDef* original_fdef = fld->Find(func_name); + if (original_fdef) { + for (const auto& attr : original_fdef->attr()) { + (*updated_fdef.mutable_attr())[attr.first] = attr.second; + } + } if (fld->Find(new_func_name)) { TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); } + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("extract_outside_compilation_for_func_after_", func_name), + *graph_out, fld); + } - return Status::OK(); + return ret_status; } Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld) { + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) { if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld); } @@ -1610,7 +1734,7 @@ Status ExtractOutsideCompilation( TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, func_name_attrs, func_name_attrs.name(), host_graph_func_name, - host_compute_core, fld, &shape_inference_graphs, + host_compute_core, flr, fld, &shape_inference_graphs, &has_outside_compilation)); TF_RETURN_IF_ERROR( ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index e07e7c5dd0cd42ddd4d643d8b36583c82056bbb2..d64cc2a103ed040cbf413ac736f97f84459e869b 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -89,7 +89,7 @@ Status ExtractOutsideCompilationForFunction( const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, const string& host_graph_func_name, - const std::map& host_compute_core, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation); @@ -101,7 +101,7 @@ Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld); + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index e9a89e34e0c7b04b4be34e367b2d0bf627c0061a..7c3a24feff81b21a5d2347d21fb80988bc3e6065 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" @@ -31,6 +32,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -222,7 +225,42 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, ShapesInferred) { EXPECT_EQ(shapes[0].dim_size(), 1); } -TEST(ExtractOutsideCompilationForFunctionTest, Basic) { +class ExtractOutsideCompilationForFunctionTest : public ::testing::Test { + public: + void SetUp() override { + SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + device_mgr_ = absl::make_unique(std::move(devices)); + } + + Status ExtractOutsideCompilationTest( + const string &xla_cluster_attr_name, + const string &outside_compilation_attr_name, + const string &xla_cluster_name, const NameAttrList &func_name_attrs, + const string &new_func_name, const string &host_graph_func_name, + const std::map &host_compute_core, + FunctionLibraryDefinition *fld, + std::vector *shape_inference_graphs, + bool *has_outside_compilation) { + OptimizerOptions opts; + pflr_ = absl::make_unique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts, + /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + return ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func_name_attrs, new_func_name, host_graph_func_name, host_compute_core, + flr, fld, shape_inference_graphs, has_outside_compilation); + } + + private: + std::unique_ptr device_mgr_; + std::unique_ptr pflr_; +}; + +TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { // Build the XLA computation func. // "const0" // "identity0" = "const0" (outside compilation cluster "0") @@ -256,7 +294,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -362,7 +400,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { } } -TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { +TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { // Build the XLA computation func. // "const0" FunctionDefLibrary fdl; @@ -384,7 +422,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -406,7 +444,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { EXPECT_EQ(host_graph->num_nodes(), 2); } -TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { +TEST_F(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { // Build the XLA computation func. // "const0" // "const1" (outside compilation cluster "0") @@ -432,7 +470,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -489,7 +527,7 @@ REGISTER_OP("XlaRecvFromHost") .Attr("key: string") .SetIsStateful(); -TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { // Build the XLA computation func. // "const0" (bool) // "const1" (int32) @@ -555,7 +593,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -651,7 +689,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { } } -TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { // Build the XLA computation func. // "const0" (bool) // "while0" (input = "const0", cond = "cond_fn", body = "body_fn") @@ -714,7 +752,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -782,4 +820,162 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { } } +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { + // Build the XLA computation func. + // "const0" (int32) + // "fn" (input = "const0") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "fn", true_fn_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + { + std::unique_ptr g(new Graph(&fld)); + + tensorflow::TensorProto tensor_proto; + tensor_proto.set_dtype(tensorflow::DT_INT32); + tensorflow::TensorShapeProto shape; + shape.add_dim()->set_size(2); + *tensor_proto.mutable_tensor_shape() = shape; + for (int i = 0; i < 2; ++i) { + tensor_proto.add_int_val(1); + } + NodeDef const_def; + TF_CHECK_OK(NodeDefBuilder("const", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor_proto) + .Finalize(&const_def)); + Status s; + Node *const_node = g->AddNode(const_def, &s); + TF_CHECK_OK(s); + + NodeDef fn_def; + TF_CHECK_OK(NodeDefBuilder("fn", "fn", &fld) + .Input("const", 0, DT_INT32) + .Finalize(&fn_def)); + Node *fn_node = g->AddNode(fn_def, &s); + TF_CHECK_OK(s); + g->AddEdge(const_node, 0, fn_node, 0); + + NodeDef ret_def; + TF_CHECK_OK(NodeDefBuilder("ret", "_Retval") + .Attr("index", 0) + .Attr("T", DT_INT32) + .Input("fn", 0, DT_INT32) + .Finalize(&ret_def)); + Node *ret_node = g->AddNode(ret_def, &s); + TF_CHECK_OK(s); + g->AddEdge(fn_node, 0, ret_node, 0); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef)); + } + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have call node for outside compilation in `fn`. + Node *call_node = node_name_index["oc_call_fn"]; + EXPECT_NE(call_node, nullptr); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("oc_func_call_host_fn"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes. + bool has_recv = false, has_send = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "_XlaRecvAtHost") { + has_recv = true; + } else if (n->type_string() == "_XlaSendFromHost") { + has_send = true; + } + } + EXPECT_TRUE(has_recv); + EXPECT_TRUE(has_send); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have call node. + Node *fn_node = node_name_index["fn"]; + EXPECT_NE(fn_node, nullptr); + EXPECT_EQ(fn_node->type_string(), "fn_oc"); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("fn_oc"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have XlaHostCompute nodes. + bool has_hc = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "XlaHostCompute") { + has_hc = true; + } + } + EXPECT_TRUE(has_hc); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 6618e3a58ab7b6374ed775cd6e4e18a6a4975588..50afec020debe3aa87c317d658891d4e58762862 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -677,6 +677,11 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; + // Deadness analysis expects a graph with source and sink edges properly + // connected but sometimes the incoming graph does not follow this invariant. + // So fix up the source and sink edges before calling into deadness analysis. + FixupSourceAndSinkEdges(options.graph->get()); + std::unique_ptr deadness; { XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3df5479a55e841380ca7b8cdd0add9fd17487091..bff4cc57ee1f3ac0fc12aaa93b1588553aec8c45 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -38,6 +38,8 @@ limitations under the License. namespace tensorflow { +constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold; + XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} @@ -60,7 +62,7 @@ XlaCompilationCache::~XlaCompilationCache() { // about? } -string XlaCompilationCache::DebugString() { +string XlaCompilationCache::DebugString() const { return "XLA JIT compilation cache"; } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 846d0c963dbfdf55f51120f2f138d12f5f63839b..02aa8f8839e2c033e06d043b0f17d89a08d5d9e6 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -88,7 +88,7 @@ class XlaCompilationCache : public ResourceBase { xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } - string DebugString() override; + string DebugString() const override; // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index fa02cf9cbef45188a6dc2f861ff036649ea92b03..f80cb1812f00d36ddb7c28ae0e77c58498058ef3 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -230,6 +230,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -677,6 +678,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -826,6 +828,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python:standard_ops", "//tensorflow/python:stateless_random_ops", ], ) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index d8123e956fac04912b4fed5bf75cc9cb55c5baf9..0366ec45fb75a21b98ebfc4bdaa903bfa908de7a 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -669,6 +669,7 @@ cc_library( name = "side_effect_util", srcs = ["side_effect_util.cc"], hdrs = ["side_effect_util.h"], + visibility = [":friends"], deps = [ "//tensorflow/core:core_cpu", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index efb75749722893100494e089c0beb96944e9f1d4..0c2bb0223905b22613a64ad54f07151f7f8590b2 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -191,6 +192,9 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, // into the functions. XlaOpKernelContext xla_op_context(op_context); + XlaContext& context = XlaContext::Get(op_context); + auto* b = context.builder(); + XlaCompiler* compiler = xla_op_context.compiler(); NameAttrList func; @@ -219,8 +223,12 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + bool add_token_input_output = + HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName); + XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = false; + compile_options.add_token_input_output = add_token_input_output; XlaCompiler::CompilationResult result; TF_RETURN_IF_ERROR( compiler->CompileFunction(compile_options, func, arguments, &result)); @@ -234,9 +242,19 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, } handles.push_back(expressions[i]->handle()); } - - XlaContext& context = XlaContext::Get(op_context); - auto* b = context.builder(); + if (add_token_input_output) { + std::vector token_input_nodes; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes)); + std::vector token_inputs; + for (const string& node_name : token_input_nodes) { + auto token_or = compiler->GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ConsumeValueOrDie()); + } + xla::XlaOp token_input = xla::AfterAll(b, token_inputs); + handles.push_back(token_input); + } auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so @@ -251,6 +269,10 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, ++computation_output; } } + if (add_token_input_output) { + TF_RETURN_IF_ERROR(compiler->SetNodeToken( + n->name(), xla::GetTupleElement(output_handle, computation_output))); + } return b->first_error(); } diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 7199b9b6feb36dd45ef51f4c38463bc715fcc38a..c2b4c28d1566f5429c5d8109db94af0c3762b131 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -99,8 +99,8 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType xla_output_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_type(0), &xla_output_type)); - xla::XlaOp argmax = XlaHelpers::ArgMax(softmax_entries, xla_output_type, - /*axis=*/class_dimension); + xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type, + /*axis=*/class_dimension); if (num_samples == 1) { argmax = xla::Reshape(argmax, {batch_size, 1}); } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index b0bc7640307149459a29e6b0b2e8e8132e4141c9..5b4f863f7418ecda0db502ce25fed2d0042bf3ca 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -212,8 +212,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, - out_backprop_tensor_shape, dilations, strides, padding, data_format, - dims); + out_backprop_tensor_shape, dilations, strides, padding, + /*explicit_paddings=*/{}, data_format, dims); } } // anonymous namespace @@ -227,6 +227,11 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + // TODO(reedwm): Support explicit padding. + if (attrs.padding == EXPLICIT) { + return errors::Unimplemented( + "XLA does not yet support Conv2D with explicit padding."); + } string data_format; TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); @@ -428,11 +433,8 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); - // The conversion logic below assumes that the data format is NHWC, so we also - // check that here. bool use_batch_group_count = - filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise && - attrs.data_format == FORMAT_NHWC; + filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; std::vector> padding(attrs.num_spatial_dims); std::vector rhs_dilation(attrs.num_spatial_dims); diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 96ddd42e2ae04d454e4fb85628d139e17a543d2e..a31b5a2cfd75f16f416615e9dc6ecca8515d35c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -351,24 +352,26 @@ struct SuppressBodyFn { auto num_outputs_so_far = values[1]; auto iou_mask = values[2]; auto included_iou = values[3]; - auto zero_r1 = xla::ConstantR1(builder, {0}); + auto zero = xla::ConstantR0(builder, 0); // Determine if current elem is active using a slice. - auto row_idx_r1 = xla::Reshape(row_idx, {1}); - auto active_elem = xla::DynamicSlice(included_iou, row_idx_r1, {1}); + // TODO(b/118437727): The only reason we need an explicit vector is because + // some old GCCs can't deduce the right type for MakeConstSpan, and + // providing a single-value initializer list directly uses the wrong + // overload. Delete this once the deprecated overload is gone. + std::vector row_idx_vector = {row_idx}; + auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1}); active_elem = xla::Reshape(active_elem, {}); // Increment output count iff current elem is not suppressed. num_outputs_so_far = xla::Select( active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), num_outputs_so_far); // Slice out the row_idx. - auto starts = xla::ConcatInDim(builder, {row_idx_r1, zero_r1}, 0); - auto row_iou = xla::DynamicSlice(iou_mask, starts, {1, num_boxes}); + auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); // Remove the diagonal from consideration. An elem cannot suppress // itself. - auto update_starts = xla::ConcatInDim(builder, {zero_r1, row_idx_r1}, 0); row_iou = xla::DynamicUpdateSlice( row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), - update_starts); + {zero, row_idx}); // Create a suppression by inverting polarity. row_iou = xla::Reshape(row_iou, {num_boxes}); auto supp_mask = xla::Not(row_iou); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 843b6bb4e658af16fd753c1a20b35dd3d18df027..978e9480eac5b522d1ee2d51a61841c6f1bbba0c 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/index_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -66,9 +65,9 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp input = ctx->Input(0); xla::XlaOp output; if (is_min_) { - output = XlaHelpers::ArgMin(input, index_xla_type, axis); + output = xla::ArgMin(input, index_xla_type, axis); } else { - output = XlaHelpers::ArgMax(input, index_xla_type, axis); + output = xla::ArgMax(input, index_xla_type, axis); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 3e7e8eae6ed406003be2843305bb6b173845d6cc..30b993045c86c6d01f8eabe55986f132f8938643 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -16,9 +16,9 @@ limitations under the License. // Native XLA implementations of indexing ops. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -74,7 +74,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // shape isn't supported. if (!ctx->compiler()->options().allow_cpu_custom_calls || (input_dims != 1 && input_dims != 2)) { - xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis); + xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis); ctx->SetOutput(0, output); return; } diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 02d71a394280a67cc264450021f91ba475aef7fc..d0c5231e843aefa68490e29475ee96bd92859aac 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -146,9 +146,9 @@ class StackPushOp : public XlaOpKernel { xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -202,9 +202,9 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + std::vector start_indices(stack_shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 2c92a585f5679242d672d0402e617ff199b94f17..dfa09b16081e93ba843a1858e68e6ff756de20c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -291,5 +291,19 @@ class ResourceScatterNdAddOp : public ResourceScatterOp { }; REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); +class ResourceScatterNdSubOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdSubOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Sub(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterNdSub"), ResourceScatterNdSubOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index b32683a682c3a6828e92572f444022dbfc84165f..941b04363f8386a7bdbe8c91ea34c9754592a52d 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -291,20 +291,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); - auto while_shape_or = builder->GetShape(while_result); - OP_REQUIRES_OK(ctx, while_shape_or.status()); - auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie()); - int max_index = body.outputs.size() + body.resource_updates.size() - 1; - OP_REQUIRES( - ctx, max_index < count, - errors::Internal("Max tuple element requested (", max_index, - ") needs to be less than tuple size (", count, ")")); - - // Sets non-variable outputs. + // Sets non-variable outputs and determine when resource variables start. + int resource_index = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], xla::GetTupleElement(while_result, i)); + ++resource_index; + } else { + break; } } if (has_token_input_output_) { @@ -326,7 +321,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); if (update.modified) { - int pos = body.outputs.size() + i; + int pos = resource_index + i; OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 4dce0a2102cf9c782850ccc7af4f14b59bd51e53..7140b6a1227a53290c3747892a55886a7f48513b 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -4,7 +4,11 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_wrapper_py", +) cc_library( name = "xla_ops", @@ -24,3 +28,14 @@ tf_gen_op_wrapper_py( ":xla_ops", ], ) + +tf_custom_op_library( + name = "_xla_ops.so", + srcs = [ + "xla_ops.cc", + ], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index fef97b98c376d9df8bbfd9cb6651216895e46bf4..9abdb04d7736e8ff5225688af4759a522d3e7fc7 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -15,6 +15,7 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") tf_py_clif_cc( name = "xla_op_registry", @@ -27,9 +28,13 @@ tf_py_clif_cc( ], ) -py_library( +tf_custom_op_py_library( name = "xla", srcs = ["xla.py"], + dso = ["//tensorflow/compiler/tf2xla/ops:_xla_ops.so"], + kernels = [ + "//tensorflow/compiler/tf2xla/ops:xla_ops", + ], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/compiler/xla:xla_data_proto_py", diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index ff9f1b9ccba2c4f3307890d5aac4ddb6cfaafcd9..c20d6a5fd1f3bd7dad30cb3359d13ed4609a2250 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -77,6 +77,7 @@ CreateResourceOpInfoMap() { add("ResourceScatterMin" , kReadWrite, kVariable); add("ResourceScatterMul" , kReadWrite, kVariable); add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdSub" , kReadWrite, kVariable); add("ResourceScatterNdUpdate" , kReadWrite, kVariable); add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index b62f8e9115229ac35c657d374c68336f1168ff77..412f31adbb7df52b2d6933be054cc6d40947dc44 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -26,6 +26,49 @@ const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { + if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { + return errors::InvalidArgument("Node ", node->DebugString(), + " does not have attribute ", + kXlaHasHostTransferAttrName); + } + + if (node->type_string() == "_XlaRecvAtHost" || + node->type_string() == "_XlaSendFromHost") { + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else if (node->type_string() == "If") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (node->type_string() == "While") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (HasNodeAttr(node->def(), "device_ordinal")) { + // Function call node containing outside compilation. + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else { + return errors::Internal("Unknown node type to set 'device_ordinal': ", + node->DebugString()); + } + return Status::OK(); +} + std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index 7081b362c36c4785164b29003a5f89cd73bcf3af..75e1f253fb08ae61b0336a8783b7449c69197dd1 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -38,6 +38,10 @@ extern const char kXlaTokenArgNodeName[]; // This node have XlaRecvAtHost/XlaSendFromHost in its associated functions. extern const char kXlaHasHostTransferAttrName[]; +// Sets device ordinal attribute for nodes with attribute +// `kXlaHasHostTransferAttrName`. +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); + // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. std::set CalculateTokenInputsForOutputToken(const Graph& g); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 32a430184099e8bac28182ea4c7790ea1b8266de..492010f7317d32a8a620147cd2cd9356d4f13fde 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -82,7 +82,7 @@ namespace { // compiled kernels. class DummyResourceForTest : public ResourceBase { public: - string DebugString() override { return "dummy"; } + string DebugString() const override { return "dummy"; } void Increment() { ++value_; } int Get() { return value_; } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index a69af70503376b6c0905deb8980abdc3254a6e47..6139bf3cea0790c2697130a993e92be96c81848b 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -61,7 +61,7 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) : compiler_(compiler), builder_(builder) {} -string XlaContext::DebugString() { return "XLA JIT context"; } +string XlaContext::DebugString() const { return "XLA JIT context"; } void XlaContext::SetRetval(int index, const XlaExpression& expression) { if (retvals_.size() <= index) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 0767d1faac14cedb8666f6cc37175eb7b55f6158..eb4ad3fe6a14b42a4df2c73c71cb6df1331fd796 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -47,7 +47,7 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder); // Virtual method defined by ResourceBase. - string DebugString() override; + string DebugString() const override; XlaCompiler* compiler() const { return compiler_; } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 00035d24b7891060adbc5ff871356b7471ab92ef..04a5d934064a9083a41cc210b48df65bbc862fff 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -34,63 +34,6 @@ limitations under the License. namespace tensorflow { -namespace { - -xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, - bool is_min) { - xla::XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); - xla::XlaOp init_value; - xla::XlaComputation reducer; - if (is_min) { - init_value = xla::MaxValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMinComputation(input_shape.element_type(), builder); - } else { - init_value = xla::MinValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMaxComputation(input_shape.element_type(), builder); - } - - xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, - /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(input_shape.rank() - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. - xla::XlaOp partial_mask = xla::ConvertElementType( - xla::Eq(input, input_max, broadcast_dims), output_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; - xla::XlaOp shift_amount = - xla::ConstantR0WithType(builder, output_type, bits_in_type); - xla::XlaOp full_mask = xla::ShiftRightArithmetic( - xla::ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - - const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis); - xla::XlaOp iota = xla::Iota(builder, output_type, axis_size); - xla::XlaOp product = - xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. - return xla::Reduce(product, xla::MinValue(builder, output_type), - xla::CreateScalarMaxComputation(output_type, builder), - /*dimensions_to_reduce=*/{axis}); - }); -} - -} // namespace - xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); @@ -148,16 +91,6 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/false); -} - -xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/true); -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 4858dfee55a393d04cd2af83916eeb40820ee368..490923526bd3acd4b167ccb3faff1d6c9e631131 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -53,16 +53,6 @@ class XlaHelpers { absl::Span shape, xla::Literal* output); - // Returns the argmax of `input` along `axis`. `output_type` is the type to - // use for the output. - static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - - // Returns the argmin of `input` along `axis`. `output_type` is the type to - // use for the output. - static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 722d1376687efa1c04158e3fd9ce539aac9d0122..aa43fc8dabf81fd44fc1b7c10f5d0501b1a55af3 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -152,7 +152,7 @@ cc_library( ":status", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/stream_executor", + "//tensorflow/stream_executor/lib", ], ) @@ -717,6 +717,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -741,6 +742,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 97fb4b9e0ee8bba23fc979cf02fc2ab79c7df1a7..df1ee330f1eb0a4556d5ec4d333c5dc4271fc5f7 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -34,6 +34,21 @@ cc_library( ], ) +xla_test( + name = "arithmetic_test", + srcs = ["arithmetic_test.cc"], + deps = [ + ":arithmetic", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "cholesky", srcs = ["cholesky.cc"], @@ -93,7 +108,6 @@ cc_library( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":constants", "//tensorflow/compiler/xla:test", @@ -147,7 +161,6 @@ cc_library( xla_test( name = "math_test", srcs = ["math_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":math", "//tensorflow/compiler/xla:literal_util", @@ -181,7 +194,6 @@ cc_library( xla_test( name = "matrix_test", srcs = ["matrix_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":matrix", ":slicing", @@ -295,7 +307,6 @@ cc_library( xla_test( name = "slicing_test", srcs = ["slicing_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":slicing", "//tensorflow/compiler/xla:literal_util", @@ -324,7 +335,6 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", "//tensorflow/compiler/xla:test", @@ -352,7 +362,10 @@ cc_library( xla_test( name = "quantize_test", srcs = ["quantize_test.cc"], - tags = ["enable_for_xla_interpreter"], + # TODO(b/122119490): re-enable TAP after fixing. + tags = [ + "notap", + ], deps = [ ":quantize", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 33ff3971d7277e619db4bec99ffdf9c0c9d230ad..3b875135af29f142463ffd783bfeaadc61ada1af 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -123,4 +123,64 @@ XlaOp Any(XlaOp predicates) { }); } +namespace { + +XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + XlaOp init_value; + XlaComputation reducer; + if (is_min) { + init_value = MaxValue(builder, input_shape.element_type()); + reducer = CreateScalarMinComputation(input_shape.element_type(), builder); + } else { + init_value = MinValue(builder, input_shape.element_type()); + reducer = CreateScalarMaxComputation(input_shape.element_type(), builder); + } + + XlaOp input_max = Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(input_shape.rank() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + XlaOp partial_mask = + ConvertElementType(Eq(input, input_max, broadcast_dims), output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; + XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type); + XlaOp full_mask = ShiftRightArithmetic( + ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + + const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis); + XlaOp iota = Iota(builder, output_type, axis_size); + XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. + return Reduce(product, MinValue(builder, output_type), + CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + }); +} + +} // namespace + +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false); +} + +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 632e8cc8bc64fad236a0226c6e93079aadde7050..d4a7812c441c351b121e5d72faf9642b06728b18 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -57,6 +57,14 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type, // Note: if predicates is zero-sized, Any() vacuously returns false. XlaOp Any(XlaOp predicates); +// Returns the argmax of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); + +// Returns the argmin of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a13839f9db89b9c07f2465867a503ef2193f8160 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using ArithmeticTest = ClientLibraryTestBase; + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/0); + + std::vector expected = {0, 2, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/1); + + std::vector expected = {0, 1, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/0); + + std::vector expected = {2, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/1); + + std::vector expected = {1, 0, 0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc index e100d47922efb6655e40a19b5c4cbb2190fac6d9..414bd1494cd32f32a5c37e84119de930678a776b 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -68,29 +68,26 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { auto body_fn = [&](XlaOp i, absl::Span loop_vars, XlaBuilder* body_builder) -> StatusOr> { - Shape col_shape; - Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = ConstantR1(body_builder, mask_vector); + std::vector row_shape_dims(major_dims.begin(), major_dims.end()); + std::vector col_shape_dims(major_dims.begin(), major_dims.end()); + row_shape_dims.push_back(1); + row_shape_dims.push_back(n); + auto mask_zeros_row = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), row_shape_dims)); + + col_shape_dims.push_back(n); + col_shape_dims.push_back(1); + auto mask_zeros_col = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), col_shape_dims)); + auto mask_range_row = - Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, row_shape_dims), + /*iota_dimension=*/n_dims - 1); auto mask_range_col = - Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, col_shape_dims), + /*iota_dimension=*/n_dims - 2); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 611fffba8d0544a011cb641802058c08d95cf5f7..77145ba7d4c72435450d3e33d57b2507eb84d2fc 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace xla { @@ -51,17 +52,17 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. - std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = ConstantR1(builder, start_as_int32); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); const int64 n_dims = shape.rank(); - TF_ASSIGN_OR_RETURN(Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return DynamicUpdateSlice(x, update, start_constant); + TF_RET_CHECK(start.size() == n_dims); + + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + std::vector start_ops(start.size()); + for (int i = 0; i < start.size(); ++i) { + start_ops[i] = ConstantR0(builder, start_as_int32[i]); + } + return DynamicUpdateSlice(x, update, start_ops); }); } @@ -90,18 +91,17 @@ std::vector ConcatVectors(absl::Span xs, return output; } -XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { +StatusOr> PrependZerosInMajorDims( + XlaOp x, absl::Span starts) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = shape.rank(); - auto zero = Reshape(ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); - } - return ConcatInDim(builder, padded_starts, 0); - }); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + auto zero = ConstantR0(builder, 0); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = starts[i]; + } + return padded_starts; } } // namespace @@ -119,7 +119,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, .subspan( /*pos=*/0, /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); auto padded_sizes = ConcatVectors(major_dims, sizes); return DynamicSlice(x, padded_starts, padded_sizes); }); @@ -127,8 +127,11 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return DynamicUpdateSlice(x, update, padded_starts); + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); + return DynamicUpdateSlice(x, update, padded_starts); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index 27ff36c7491ab8397d46f3a49493ff2b904deb2d..0fbd138aca1e86f219d0459086fc09d20844f135 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) { auto x = ConstantR1(&builder, inputs); xla::GetTupleElement(xla::TopK(x, kSize), 0); - std::sort(inputs.begin(), inputs.end(), std::greater()); + absl::c_sort(inputs, std::greater()); ComputeAndCompareR1(&builder, inputs, {}); } diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc index 6061e64656e859adcc52c250e7775a878399f4e6..c2f31742e9eff9f325fb71160b4ec3aea928d15e 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -165,10 +165,10 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, // The first or last diagonal element should be set to 1 instead of -1 // though, since we never update it auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = (lower) ? 0 : block_size - 1; - auto output_block = DynamicUpdateSlice( - neg_identity, pos_one, - /*start_indices=*/ConstantR1(builder, 2, start_index)); + auto start_index = ConstantR0(builder, (lower) ? 0 : block_size - 1); + auto output_block = + DynamicUpdateSlice(neg_identity, pos_one, + /*start_indices=*/{start_index, start_index}); // Broadcast diag([1, -1, -1, ...]) to every block XlaOp output = Broadcast(output_block, @@ -211,12 +211,10 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, auto body_out = GetTupleElement(input_tuple, 1); auto body_input = GetTupleElement(input_tuple, 2); - auto zero = ConstantR1(bodyb.get(), 1, 0); + auto zero = ConstantR0(bodyb.get(), 0); auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; - auto start_indices = - ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); auto input_row = - DynamicSlice(body_input, start_indices, + DynamicSlice(body_input, {zero, j, zero}, /*slice_sizes=*/{num_blocks, 1, block_size}); // We want -L21 L11^{-1} @@ -230,7 +228,7 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - body_out = DynamicUpdateSlice(body_out, update, start_indices); + body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); auto next_i = i + ScalarLike(i, 1); Tuple(bodyb.get(), {next_i, body_out, body_input}); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 049cd15738a619294b19d5cf74ca514d7b4a00ad..48b5f94538f453785194bc434a91ee0a10c020c2 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -164,9 +164,8 @@ StatusOr LocalExecutable::Run( // ExecutableRunOptions.eigen_intra_op_thread_pool. // *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). - ServiceExecutableRunOptions service_options( - run_options, backend_->StreamBorrower(), - backend_->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions service_options(run_options, + backend_->StreamBorrower()); if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7c72cdfeb50c552017647ef4a952e139b0b4adfa..59e156eb7292fe3c28345faa7247e08e3d46c487 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -29,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -192,9 +195,9 @@ StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { } void XlaBuilder::IsConstantVisitor(const int64 op_handle, - std::set* visited, + absl::flat_hash_set* visited, bool* is_constant) const { - if (visited->count(op_handle) != 0 || !*is_constant) { + if (visited->contains(op_handle) || !*is_constant) { return; } @@ -244,6 +247,29 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num) { + bool param_exists = false; + for (HloInstructionProto& instr : instructions_) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == target_param_num) { + param_exists = true; + Shape param_shape(instr.shape()); + Shape* param_shape_ptr = ¶m_shape; + for (int64 index : target_param_index) { + param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index); + } + param_shape_ptr->set_dynamic_dimension(target_dim_num, + /*is_dynamic=*/true); + *instr.mutable_shape() = param_shape.ToProto(); + } + } + + if (!param_exists) { + return InvalidArgument( + "Asked to mark parameter %lld as dynamic sized parameter, but the " + "doesn't exists", + target_param_num); + } + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, dynamic_size_param_index}, @@ -310,7 +336,10 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->add_computations()->Swap(&e.second); } module->add_computations()->Swap(&entry); - + if (!input_output_aliases_.empty()) { + TF_RETURN_IF_ERROR( + PopulateInputOutputAlias(module, program_shape, input_output_aliases_)); + } *(module->mutable_dynamic_parameter_binding()) = dynamic_parameter_binding_.ToProto(); @@ -323,6 +352,35 @@ StatusOr XlaBuilder::Build(int64 root_id) { return std::move(computation); } +/* static */ Status XlaBuilder::PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases) { + HloInputOutputAliasConfig config(program_shape.result()); + for (auto& alias : input_output_aliases) { + // The HloInputOutputAliasConfig does not do parameter validation as it only + // carries the result shape. Maybe it should be constructed with a + // ProgramShape to allow full validation. We will still get an error when + // trying to compile the HLO module, but would be better to have validation + // at this stage. + if (alias.param_number >= program_shape.parameters_size()) { + return InvalidArgument("Invalid parameter number %ld (total %ld)", + alias.param_number, + program_shape.parameters_size()); + } + const Shape& parameter_shape = program_shape.parameters(alias.param_number); + if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) { + return InvalidArgument("Invalid parameter %ld index: %s", + alias.param_number, + alias.param_index.ToString().c_str()); + } + TF_RETURN_IF_ERROR(config.SetUpAlias( + alias.output_index, alias.param_number, alias.param_index, + HloInputOutputAliasConfig::AliasKind::kUserAlias)); + } + *module->mutable_input_output_alias() = config.ToProto(); + return Status::OK(); +} + StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, absl::Span broadcast_dimensions) { @@ -659,7 +717,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, GetShape(start_indices)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( - operand_shape, start_indices_shape, slice_sizes)); + operand_shape, {start_indices_shape}, slice_sizes)); *instr.mutable_shape() = shape.ToProto(); for (int64 size : slice_sizes) { @@ -671,6 +729,34 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, }); } +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shapes, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + std::vector operands = {operand}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); + }); +} + XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -680,13 +766,38 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, {start_indices_shape})); + *instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicUpdateSliceShape( - operand_shape, update_shape, start_indices_shape)); + operand_shape, update_shape, start_indices_shapes)); *instr.mutable_shape() = shape.ToProto(); + std::vector operands = {operand, update}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - {operand, update, start_indices}); + operands); }); } @@ -2383,7 +2494,7 @@ StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(LookUpInstruction(operand).status()); bool is_constant = true; - std::set visited; + absl::flat_hash_set visited; IsConstantVisitor(operand.handle(), &visited, &is_constant); return is_constant; } @@ -2717,12 +2828,21 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } +XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); +} XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); +} + XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 6e9b025e5d70c03e9f4c7e7fbc89976f314d48d7..8908d172fa89632ead48f954de12066af12411c7 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -276,7 +276,22 @@ class XlaBuilder { int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num); + // Adds a new input/output alias. Since the input/ouput shape information are + // not available until the computation is built, and eventual error in the + // arguments of this API will be detected only at computation Build() time. + void SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + input_output_aliases_.push_back({output_index, param_number, param_index}); + } + private: + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + ShapeIndex output_index; + int64 param_number; + ShapeIndex param_index; + }; + // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id); @@ -344,11 +359,18 @@ class XlaBuilder { XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); @@ -712,7 +734,8 @@ class XlaBuilder { // operation such as `RngNormal` or `Infeed`. The visitor walks the // computation starting at a given operation and sets is_constant to false iff // a parameter or stateful operation is encountered. - void IsConstantVisitor(const int64 op_handle, std::set* visited, + void IsConstantVisitor(const int64 op_handle, + absl::flat_hash_set* visited, bool* is_constant) const; // Checks bounds for convolution parameters. @@ -730,6 +753,12 @@ class XlaBuilder { int64 GetNextId() { return ++next_id_; } + // Populates the module with the input/output alias information stored within + // the input_output_aliases vector. + static Status PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases); + string name_; // Name to use for the built computation. // The next sequential ID for every instruction/computation contained within @@ -749,6 +778,9 @@ class XlaBuilder { // Dynamic parameter configuration of this computation. DynamicParameterBinding dynamic_parameter_binding_; + // Holds the input/output alias information populated by the SetUpAlias() API. + std::vector input_output_aliases_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. absl::flat_hash_map handle_to_index_; @@ -850,9 +882,14 @@ class XlaBuilder { friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + friend XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension); @@ -1294,10 +1331,15 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // The size of the slice in each dimension is passed in 'slice_sizes', // which specify the end point of exclusive slice intervals in each // dimension [start, start + size). -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); @@ -1313,10 +1355,15 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] // [7 8 9] [7 8 9 ] // -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo update dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index b3f5be300d3f15397ad33858a6a9cab5f6029688..abc11b4732dd1c041d9c78aa7cac40129d82a785 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -446,6 +446,26 @@ TEST_F(XlaBuilderTest, ProtoMatches) { EXPECT_EQ(c0_string, c1_string); } +TEST_F(XlaBuilderTest, DynamicParameter) { + std::vector computations; + XlaBuilder b("builder"); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1, + /*dynamic_size_param_index=*/{}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/p0)); + const Shape& param_shape = module->entry_computation() + ->parameter_instruction(0) + ->shape() + .tuple_shapes(1); + EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); +} + TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { XlaBuilder b(TestName()); AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); @@ -455,5 +475,31 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { ::testing::HasSubstr("All operands to AfterAll must be tokens")); } +TEST_F(XlaBuilderTest, CheckInputOutputAlias) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1"); + auto add = Add(p0, p1); + auto sub = Sub(p0, p1); + auto root = Tuple(&b, {add, sub}); + + b.SetUpAlias({1}, 0, {}); + b.SetUpAlias({0}, 1, {}); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + + const HloInputOutputAliasConfig& config = module->input_output_alias_config(); + EXPECT_TRUE(config.ParameterHasAlias(0, {})); + EXPECT_TRUE(config.ParameterHasAlias(1, {})); + + auto alias_p0 = config.GetAliasedOutput(0, {}); + ASSERT_TRUE(alias_p0.has_value()); + EXPECT_EQ(*alias_p0, ShapeIndex({1})); + + auto alias_p1 = config.GetAliasedOutput(1, {}); + ASSERT_TRUE(alias_p1.has_value()); + EXPECT_EQ(*alias_p1, ShapeIndex({0})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 9a9cd08c301502cbda8858225182d95fca4bf7ae..59ba9bb6584c580b237fc2de126db58453854581 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -871,9 +871,7 @@ DotGeneral performs the sum of products over contracting dimensions specified in 'dimension_numbers'. Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need -to be the same, but must be listed in the same order in both -'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes. -There must be exactly one contracting dimension on both 'lhs' and 'rhs'. +to be the same and but must have the same dimension sizes. Example with contracting dimension numbers: @@ -892,10 +890,8 @@ DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, {15.0, 30.0} } ``` -Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same -dimension number, must be listed in the same order in both arrays, must -have the same dimension sizes, and must be ordered before contracting and -non-contracting/non-batch dimension numbers. +Associated batch dimension numbers from the 'lhs' and 'rhs' must +have the same dimension sizes. Example with batch dimension numbers (batch size 2, 2x2 matrices): @@ -944,21 +940,21 @@ dimension: [start, start + size). The shape of `start_indices` must be rank == `DynamicSlice(operand, start_indices, size_indices)` -| Arguments | Type | Semantics | -| --------------- | ------------------- | ----------------------------------- | -| `operand` | `XlaOp` | N dimensional array of type T | -| `start_indices` | `XlaOp` | Rank 1 array of N integers | -: : : containing the starting indices of : -: : : the slice for each dimension. Value : -: : : must be greater than or equal to : -: : : zero. : -| `size_indices` | `ArraySlice` | List of N integers containing the | -: : : slice size for each dimension. Each : -: : : value must be strictly greater than : -: : : zero, and start + size must be less : -: : : than or equal to the size of the : -: : : dimension to avoid wrapping modulo : -: : : dimension size. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : +| `size_indices` | `ArraySlice` | List of N integers containing the | +: : : slice size for each dimension. : +: : : Each value must be strictly : +: : : greater than zero, and start + : +: : : size must be less than or equal to : +: : : the size of the dimension to avoid : +: : : wrapping modulo dimension size. : The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -1009,19 +1005,22 @@ the rank of `operand`. `DynamicUpdateSlice(operand, update, start_indices)` -| Arguments | Type | Semantics | -| --------------- | ------- | ------------------------------------------------ | -| `operand` | `XlaOp` | N dimensional array of type T | -| `update` | `XlaOp` | N dimensional array of type T containing the | -: : : slice update. Each dimension of update shape : -: : : must be strictly greater than zero, and start + : -: : : update must be less than or equal to the operand : -: : : size for each dimension to avoid generating : -: : : out-of-bounds update indices. : -| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the | -: : : starting indices of the slice for each : -: : : dimension. Value must be greater than or equal : -: : : to zero. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `update` | `XlaOp` | N dimensional array of type T | +: : : containing the slice update. Each : +: : : dimension of update shape must be : +: : : strictly greater than zero, and : +: : : start + update must be less than : +: : : or equal to the operand size for : +: : : each dimension to avoid generating : +: : : out-of-bounds update indices. : +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -1095,7 +1094,7 @@ When `Op` is `Rem`, the sign of the result is taken from the dividend, and the absolute value of the result is always less than the divisor's absolute value. Integer division overflow (signed/unsigned division/remainder by zero or signed -divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined +division/remainder of `INT_SMIN` with `-1`) produces an implementation defined value. An alternative variant with different-rank broadcasting support exists for these diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 42649d669222e6a871c81a01c54148c93f12a140..2fe9b56c6bdffb931726f60ab75081361b43ebb4 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -290,8 +290,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { if (shape.IsTuple()) { // Tuple shape: all subshapes must have a layout. - return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), - [](const Shape& s) { return HasLayout(s); }); + return absl::c_all_of(shape.tuple_shapes(), + [](const Shape& s) { return HasLayout(s); }); } else if (!shape.IsArray()) { // Opaque, token types etc. ignore layout. return true; @@ -424,7 +424,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { positions_in_layout.push_back( PositionInContainer(layout.minor_to_major(), dim)); } - std::sort(positions_in_layout.begin(), positions_in_layout.end()); + absl::c_sort(positions_in_layout); for (size_t i = 1; i < positions_in_layout.size(); ++i) { if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) { return false; diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index f4376d66af8b1eb069f5f711e4a619df9b0b7de2..258bc966b1a2ee3faa75e3319185489a107b3461 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -387,7 +387,14 @@ class NearComparator { rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); - rel_error = abs_error / FpAbsoluteValue(expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } const bool is_abs_mismatch = abs_error > error_.abs; const bool is_rel_mismatch = rel_error > error_.rel; diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 4eab4fa4290c270697c00be20840cf4e85459183..ad1699a1ae65180d56617b069d8b2e1d7d81c38c 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) { const auto metric_greater = [](const Entry& a, const Entry& b) { return a.metric > b.metric; }; - std::sort(entries_.begin(), entries_.end(), metric_greater); + absl::c_sort(entries_, metric_greater); // Create the report AppendLine(); @@ -117,7 +117,7 @@ std::vector MetricTableReport::MakeCategories( auto metric_sum_greater = [](const Category& a, const Category& b) { return a.metric_sum > b.metric_sum; }; - std::sort(categories.begin(), categories.end(), metric_sum_greater); + absl::c_sort(categories, metric_sum_greater); return categories; } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b0f0e9e57060f43e464ea2b2e20b9000679d18b6..c153603105a48c473572c51013f493a45c20c4b1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -927,6 +927,22 @@ LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, conjugate_a); } +LocalOp LocalComputationBuilder::Gather( + const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes) { + return xla::Gather(input.op(), start_indices.op(), dimension_numbers, + slice_sizes); +} + +LocalOp LocalComputationBuilder::Scatter( + const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, const LocalComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), + update_computation.computation(), dimension_numbers); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 5e8341592100bc1eba4d1c17b0c2dd0e0888fdb1..98759cf984751d2cef8df4449d392ace786a8ebc 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -418,6 +418,15 @@ class LocalComputationBuilder { LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, bool lower, bool transpose_a, bool conjugate_a); + LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes); + + LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, + const LocalComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index bf5d667c6a12972845735983a74264ea05675971..66ecee5c4d44663838055065d9884a87a512b30b 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -34,6 +34,8 @@ limitations under the License. // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto +// GatherDimensionNumbers proto <- corresponding Python proto +// ScatterDimensionNumbers proto <- corresponding Python proto // // Arrows indicate whether a conversion only ever occurs in one // direction, or whether it is maintained bidirectionally. @@ -167,8 +169,41 @@ bool HandleStringAttribute(PyObject* o, return true; // Handled string attribute, ok! } +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field) { + PyObject* seq = PyObject_GetAttrString(o, attr_name); + if (!seq) { + return false; + } + + int length = PySequence_Size(seq); + if (length == -1) { + Py_DECREF(seq); + return false; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(seq, i); + if (!item) { + Py_DECREF(seq); + return false; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(seq); + return false; + } + *field->Add() = dimension; + Py_DECREF(item); + } + Py_DECREF(seq); + return true; } -} + +} // namespace swig +} // namespace xla %} // Required to use PyArray_* functions. @@ -657,128 +692,27 @@ tensorflow::ImportNumpy(); %typemap(in) const DotDimensionNumbers& (DotDimensionNumbers dimension_numbers) { - int length; - - /* lhs_contracting_dimensions */ - PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( - $input, "lhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(lhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_lhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(lhs_contracting_dimensions); - - /* rhs_contracting_dimensions */ - PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( - $input, "rhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_contracting_dimensions", + dimension_numbers.mutable_lhs_contracting_dimensions())) { SWIG_fail; } - - length = PySequence_Size(rhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(rhs_contracting_dimensions); + if (!HandleRepeatedInt64Attribute( + $input, "rhs_contracting_dimensions", + dimension_numbers.mutable_rhs_contracting_dimensions())) { SWIG_fail; } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_contracting_dimensions); - - /* lhs_batch_dimensions */ - PyObject* lhs_batch_dimensions = PyObject_GetAttrString( - $input, "lhs_batch_dimensions"); - if (!lhs_batch_dimensions) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_batch_dimensions", + dimension_numbers.mutable_lhs_batch_dimensions())) { SWIG_fail; } - - length = PySequence_Size(lhs_batch_dimensions); - if (length == -1) { - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); - if (!item) { - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - dimension_numbers.add_lhs_batch_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(lhs_batch_dimensions); - - /* rhs_batch_dimensions */ - PyObject* rhs_batch_dimensions = PyObject_GetAttrString( - $input, "rhs_batch_dimensions"); - if (!rhs_batch_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(rhs_batch_dimensions); - if (length == -1) { - Py_DECREF(rhs_batch_dimensions); + if (!HandleRepeatedInt64Attribute( + $input, "rhs_batch_dimensions", + dimension_numbers.mutable_rhs_batch_dimensions())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); - if (!item) { - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_batch_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_batch_dimensions); - $1 = &dimension_numbers; } @@ -861,85 +795,80 @@ tensorflow::ImportNumpy(); dimension_numbers.set_kernel_input_feature_dimension(value); PyObject* o; - int length; - o = PyObject_GetAttrString($input, "input_spatial_dimensions"); - if (!o) { + if (!HandleRepeatedInt64Attribute( + $input, "input_spatial_dimensions", + dimension_numbers.mutable_input_spatial_dimensions())) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + if (!HandleRepeatedInt64Attribute( + $input, "kernel_spatial_dimensions", + dimension_numbers.mutable_kernel_spatial_dimensions())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_input_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "output_spatial_dimensions", + dimension_numbers.mutable_output_spatial_dimensions())) { + SWIG_fail; } - Py_DECREF(o); - o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); - if (!o) { + $1 = &dimension_numbers; +} + +// GatherDimensionNumbers + +%typemap(in) const GatherDimensionNumbers& + (GatherDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "offset_dims", + dimension_numbers.mutable_offset_dims())) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + if (!HandleRepeatedInt64Attribute( + $input, "collapsed_slice_dims", + dimension_numbers.mutable_collapsed_slice_dims())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_kernel_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "start_index_map", + dimension_numbers.mutable_start_index_map())) { + SWIG_fail; } - Py_DECREF(o); - o = PyObject_GetAttrString($input, "output_spatial_dimensions"); - if (!o) { + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// ScatterDimensionNumbers + +%typemap(in) const ScatterDimensionNumbers& + (ScatterDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "update_window_dims", + dimension_numbers.mutable_update_window_dims())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_output_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "inserted_window_dims", + dimension_numbers.mutable_inserted_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "scatter_dims_to_operand_dims", + dimension_numbers.mutable_scatter_dims_to_operand_dims())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; } - Py_DECREF(o); + dimension_numbers.set_index_vector_dim(value); $1 = &dimension_numbers; } @@ -1151,6 +1080,8 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::QR; %unignore xla::swig::LocalComputationBuilder::TriangularSolve; %unignore xla::swig::LocalComputationBuilder::CustomCall; +%unignore xla::swig::LocalComputationBuilder::Gather; +%unignore xla::swig::LocalComputationBuilder::Scatter; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DestructureXrtAllocationTuple; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 378bbdcb175f10d73da87f5286cf5129477a124c..4e71121c097e1eca2d7fbd3299b17e06dd8b8e39 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1477,6 +1477,18 @@ class ComputationBuilder(object): return self._client.TriangularSolve( a, b, left_side, lower, transpose_a, conjugate_a) + def Gather(self, a, start_indices, dimension_numbers, slice_sizes): + """Enqueues a Gather operation onto the computation.""" + return self._client.Gather(a, start_indices, dimension_numbers, + slice_sizes) + + def Scatter(self, a, scatter_indices, updates, update_computation, + dimension_numbers): + """Enqueues a Scatter operation onto the computation.""" + return self._client.Scatter( + a, scatter_indices, updates, update_computation.computation, + dimension_numbers,) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 002a20e60a9fbe117af991731a555e60eef9397a..874e087eb6d4b785066edae21b1d11ebb024cd3e 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -1129,6 +1129,21 @@ class SingleOpTest(LocalComputationTest): self.assertFalse(c.IsConstant(non_const_expr)) # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.xla_data_pb2.GatherDimensionNumbers() + dnums.offset_dims.append(1) + dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) + g = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) + class EmbeddedComputationsTest(LocalComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" @@ -1186,6 +1201,14 @@ class EmbeddedComputationsTest(LocalComputationTest): c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) return c.Build() + def _CreateBinaryAddS32Computation(self): + """Computation (s32, s32) -> s32 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayS32(0)), + c.ParameterFromNumpy(NumpyArrayS32(0))) + return c.Build() + def _CreateBinaryAddF32Computation(self): """Computation (f32, f32) -> f32 that adds its two parameters.""" c = self._NewComputation("add_param0_by_param1") @@ -1568,6 +1591,23 @@ class EmbeddedComputationsTest(LocalComputationTest): execution.join() self.assertEqual(want, got) + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.xla_data_pb2.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), + self._CreateBinaryAddS32Computation(), dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) + self._ExecuteAndCompareClose(c, expected=expected) + class ErrorTest(LocalComputationTest): diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 3ba67f69e8cd1f230095d047ec9075f072e26fd9..08b78ee244844f41d551d7e249cec0cbf157d639 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -550,7 +551,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloEvaluator evaluator; Literal result_literal = - evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); + evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); CHECK_EQ(result_literal.shape().rank(), 4); auto result = @@ -605,24 +606,26 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); - const std::set dim_set(dims.begin(), dims.end()); + const absl::flat_hash_set dim_set(dims.begin(), dims.end()); CHECK_EQ(dim_set.size(), 3); - for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { - for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); + for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1()); + ++a0) { + for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2()); ++a1) { - for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); + for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3()); ++a2) { - for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); + for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) { float accumulator = init; - for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); - ++i0) { - for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); - ++i1) { + for (int64 i0 = 0; + i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) { + for (int64 i1 = 0; + i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) { for (int64 i2 = 0; - i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { + i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) { for (int64 i3 = 0; - i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { + i3 == 0 || (dim_set.contains(3) && i3 < array.n4()); + ++i3) { // Handle zero-sized arrays. if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && array.n4() > 0) { diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d8736c819687482a9dead57bdeacff8e75dce105..728adb67a0d64c5dcb7a6e733b5272f6db911419 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1,6 +1,14 @@ # Description: # XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -12,15 +20,6 @@ package_group( ], ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_proto_library_py", -) - xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], @@ -237,6 +236,7 @@ cc_library( ], hdrs = ["hlo_evaluator.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", ":hlo_casting_utils", ":hlo_query", @@ -516,6 +516,7 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -678,6 +679,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -696,6 +698,7 @@ cc_library( ":compiler", ":computation_layout", ":device_memory_allocator", + ":dynamic_dimension_inference", ":executable", ":execution_tracker", ":hlo", @@ -1003,6 +1006,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1137,6 +1141,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1580,6 +1585,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1866,8 +1873,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -1931,6 +1939,46 @@ cc_library( ], ) +cc_library( + name = "dynamic_padder", + srcs = ["dynamic_padder.cc"], + hdrs = ["dynamic_padder.h"], + deps = [ + ":dynamic_dimension_inference", + ":hlo_dce", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "dynamic_padder_test", + srcs = ["dynamic_padder_test.cc"], + deps = [ + ":dynamic_padder", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "dynamic_dimension_inference_test", srcs = ["dynamic_dimension_inference_test.cc"], @@ -2116,6 +2164,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2288,6 +2337,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2548,6 +2598,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2592,6 +2643,7 @@ tf_cc_test( srcs = ["hlo_verifier_test.cc"], deps = [ ":hlo", + ":hlo_module_config", ":hlo_parser", ":hlo_verifier", ":layout_assignment", @@ -2599,6 +2651,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -2969,6 +3022,7 @@ cc_library( srcs = ["hlo_get_dimension_size_rewriter.cc"], hdrs = ["hlo_get_dimension_size_rewriter.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", ":hlo_pass", ":shape_inference", @@ -3186,6 +3240,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -3403,6 +3458,7 @@ cc_library( ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) @@ -3574,6 +3630,7 @@ cc_library( tf_cc_test( name = "indexed_array_analysis_test", srcs = ["indexed_array_analysis_test.cc"], + extra_copts = ["-Wno-string-plus-int"], deps = [ ":hlo_matchers", ":indexed_array_analysis", @@ -3675,6 +3732,7 @@ cc_library( ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -3686,6 +3744,38 @@ cc_library( ], ) +cc_library( + name = "dynamic_index_splitter", + srcs = ["dynamic_index_splitter.cc"], + hdrs = ["dynamic_index_splitter.h"], + deps = [ + ":hlo_casting_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "dynamic_index_splitter_test", + srcs = ["dynamic_index_splitter_test.cc"], + deps = [ + ":dynamic_index_splitter", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "ar_crs_combiner_test", srcs = ["ar_crs_combiner_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index cad70a8d10f4e7bed22b11039309d5c7d02e650f..5ac746c9f3fc0503c143da2f3da8b312ae0a4280 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -26,6 +26,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -369,6 +371,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // If the sort instruction has a tuple shape then looks for unused output + // values and removes them from the sort instruction. Returns true if the + // graph have been modified. + StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -1373,6 +1380,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + HloDynamicSliceInstruction* dynamic_slice = + lhs_is_dynamic_slice ? Cast(lhs) + : Cast(rhs); // ctA: HloInstruction* left_operand = @@ -1390,8 +1400,6 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. - HloInstruction* original_start_indices = - lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); // Position of start: int index_of_non_zero_start = lhs_is_dynamic_slice ? 1 - lhs_contracting_dimension @@ -1400,23 +1408,19 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( int index_of_zero_start = 1 - index_of_non_zero_start; // Slice out start and 0 components and reorder if necessary. - auto indices_type = original_start_indices->shape().element_type(); + auto indices_type = dynamic_slice->operand(1)->shape().element_type(); Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); HloInstruction* non_zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_non_zero_start}, - {index_of_non_zero_start + 1}, {1})); + dynamic_slice->mutable_operand(1 + index_of_non_zero_start); HloInstruction* zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_zero_start}, - {index_of_zero_start + 1}, {1})); - HloInstruction* new_start_indices = - lhs_is_dynamic_slice - ? computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {non_zero_start, zero_start}, 0)) - : computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {zero_start, non_zero_start}, 0)); + dynamic_slice->mutable_operand(1 + index_of_zero_start); + std::vector new_start_indices; + if (lhs_is_dynamic_slice) { + new_start_indices = {non_zero_start, zero_start}; + } else { + new_start_indices = {zero_start, non_zero_start}; + } // Build DynamicSlice(ctA x ctB). const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; @@ -2209,8 +2213,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { auto dim_is_one = [&](int64 i) -> bool { return reverse->shape().dimensions(i) == 1; }; - if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), - dim_is_one)) { + if (absl::c_all_of(reverse->dimensions(), dim_is_one)) { return ReplaceInstruction(reverse, reverse->mutable_operand(0)); } return Status::OK(); @@ -2485,9 +2488,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // Create a new reduce with the combined reduction dimensions of both // reduces. std::vector arg_dims = arg->dimensions(); - std::sort(arg_dims.begin(), arg_dims.end()); + absl::c_sort(arg_dims); std::vector reduce_dims = reduce->dimensions(); - std::sort(reduce_dims.begin(), reduce_dims.end()); + absl::c_sort(reduce_dims); // Transform reduce_dims to the same rank as the operand of the operand. for (int64 arg_dim : arg_dims) { for (int64& dim : reduce_dims) { @@ -2533,7 +2536,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } if (can_move_reshape_into_reduce) { changed_ = true; - std::unordered_set dimensions_not_to_reduce; + absl::flat_hash_set dimensions_not_to_reduce; for (auto dim_pair : unmodified_dims) { if (arg_dim_in_output[dim_pair.second]) { dimensions_not_to_reduce.insert(dim_pair.first); @@ -2541,7 +2544,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } std::vector new_reduce_dimensions; for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) { - if (dimensions_not_to_reduce.count(i) == 0) { + if (!dimensions_not_to_reduce.contains(i)) { new_reduce_dimensions.push_back(i); } } @@ -2595,51 +2598,53 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } - // A reduce window can be expressed as a reduce and a reshape if all - // dimensions either have a window size of one or the entire dimension. If - // there is no stride, dilation, or padding, this is as easy as checking the - // size of the output shape and window dimension. - // - // The reshape is a bitcast since it adds one-sized dimensions. Often these - // ones are immediately removed as well with another reshape. The - // implementation of reduce tends to be slightly more efficient at reducing - // entire dimensions compared to reduce window. - auto effective_reduce_dims = [&] { - if (window_util::HasStride(window) || window_util::HasDilation(window) || - window_util::HasPadding(window)) { - return absl::InlinedVector{}; - } - absl::InlinedVector reduce_dims; - for (int64 i = 0; i < window.dimensions_size(); ++i) { - if (window.dimensions(i).size() == 1) { - continue; - } else if (reduce_window->shape().dimensions(i) == 1) { - reduce_dims.push_back(i); - } else { + if (options_.enable_window_reduce_to_reduce_replacement()) { + // A reduce window can be expressed as a reduce and a reshape if all + // dimensions either have a window size of one or the entire dimension. If + // there is no stride, dilation, or padding, this is as easy as checking the + // size of the output shape and window dimension. + // + // The reshape is a bitcast since it adds one-sized dimensions. Often these + // ones are immediately removed as well with another reshape. The + // implementation of reduce tends to be slightly more efficient at reducing + // entire dimensions compared to reduce window. + auto effective_reduce_dims = [&] { + if (window_util::HasStride(window) || window_util::HasDilation(window) || + window_util::HasPadding(window)) { return absl::InlinedVector{}; } - } - return reduce_dims; - }(); + absl::InlinedVector reduce_dims; + for (int64 i = 0; i < window.dimensions_size(); ++i) { + if (window.dimensions(i).size() == 1) { + continue; + } else if (reduce_window->shape().dimensions(i) == 1) { + reduce_dims.push_back(i); + } else { + return absl::InlinedVector{}; + } + } + return reduce_dims; + }(); - // If a reduce window can be expressed as a reduce, do so and reshape the - // output. - if (!effective_reduce_dims.empty()) { - Shape reduce_shape = ShapeUtil::FilterDimensions( - [&](int64 dim) { - return !absl::c_linear_search(effective_reduce_dims, dim); - }, - reduce_window->shape()); - HloInstruction* reduce = - computation_->AddInstruction(HloInstruction::CreateReduce( - /*shape=*/reduce_shape, - /*operand=*/operand, - /*init_value=*/reduce_window->mutable_operand(1), - /*dimensions_to_reduce=*/effective_reduce_dims, - /*reduce_computation=*/function)); - return ReplaceWithNewInstruction( - reduce_window, - HloInstruction::CreateReshape(reduce_window->shape(), reduce)); + // If a reduce window can be expressed as a reduce, do so and reshape the + // output. + if (!effective_reduce_dims.empty()) { + Shape reduce_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(effective_reduce_dims, dim); + }, + reduce_window->shape()); + HloInstruction* reduce = + computation_->AddInstruction(HloInstruction::CreateReduce( + /*shape=*/reduce_shape, + /*operand=*/operand, + /*init_value=*/reduce_window->mutable_operand(1), + /*dimensions_to_reduce=*/effective_reduce_dims, + /*reduce_computation=*/function)); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateReshape(reduce_window->shape(), reduce)); + } } // This optimization folds a pad op into reduce_window. @@ -2814,6 +2819,69 @@ Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor::RemoveUnusedOperandFromSort( + HloInstruction* sort) { + if (!sort->shape().IsTuple()) { + return false; + } + + if (sort->parent()->root_instruction() == sort) { + // Can't analyse users of the root instruction. + return false; + } + + // Index 0 is the sorting key used by the sort HLO itself. + absl::flat_hash_set used_indices{0}; + for (const HloInstruction* user : sort->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + // Can't analyse users other then get-tuple-element. + return false; + } + used_indices.insert(user->tuple_index()); + } + + if (used_indices.size() == sort->operand_count()) { + // All operands are used. + return false; + } + + std::vector operands{sort->mutable_operand(0)}; + std::vector new_shapes{sort->operand(0)->shape()}; + for (int64 i = 1; i < sort->operand_count(); ++i) { + if (used_indices.count(i)) { + operands.push_back(sort->mutable_operand(i)); + new_shapes.push_back(sort->operand(i)->shape()); + } + } + Shape new_sort_shape = new_shapes.size() == 1 + ? new_shapes[0] + : ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation_->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, operands)); + + // Map from original get-tuple-element tuple index to new HLO instruction + absl::flat_hash_map result_map; + if (new_sort->shape().IsTuple()) { + // Old sort key maps to new sort key. + int64 new_index = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (used_indices.count(i)) { + result_map[i] = + computation_->AddInstruction(HloInstruction::CreateGetTupleElement( + new_shapes[new_index], new_sort, new_index)); + ++new_index; + } + } + } else { + result_map[0] = new_sort; + } + for (HloInstruction* user : sort->users()) { + TF_RETURN_IF_ERROR( + user->ReplaceAllUsesWith(result_map.at(user->tuple_index()))); + } + return true; +} + Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { auto operand = sort->mutable_operand(0); int64 dimension_to_sort = sort->dimensions(0); @@ -2826,6 +2894,14 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } + + // Remove the unused values from a key-value sort. + TF_ASSIGN_OR_RETURN(bool removed_operand, RemoveUnusedOperandFromSort(sort)); + if (removed_operand) { + changed_ = true; + return Status::OK(); + } + if (!options_.enable_permutation_sort_replacement()) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index d2775b9fafa7e4c625f5d181114e80e7369f9c78..1acaadaaa03803e467fd9bf6b22f5793fd6a1ec9 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -75,12 +75,24 @@ class AlgebraicSimplifierOptions { return enable_permutation_sort_replacement_; } + // If enable_window_reduce_replacement is true, the kReduceWindow instruction + // can be optimized by replacement with simpler operations. + void set_enable_window_reduce_to_reduce_replacement( + bool enable_window_reduce_to_reduce_replacement) { + enable_window_reduce_to_reduce_replacement_ = + enable_window_reduce_to_reduce_replacement; + } + bool enable_window_reduce_to_reduce_replacement() const { + return enable_window_reduce_to_reduce_replacement_; + } + private: ValidBitcastCallback valid_bitcast_callback_; bool is_layout_sensitive_{false}; bool enable_dot_strength_reduction_{true}; bool enable_conv_simplification_{true}; bool enable_permutation_sort_replacement_{false}; + bool enable_window_reduce_to_reduce_replacement_{true}; }; // A pass which performs algebraic simplifications. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 51ad748ff829966d04e360d457f2ae081ec187eb..916b953cfc25693b1f75095e28b3167dc8c740e1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2709,6 +2709,74 @@ TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } +TEST_F(AlgebraicSimplifierTest, RemoveUnusedSortOperandArrayResult) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1} + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, RemoveUnusedSortOperandTuple) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,87] parameter(0) + values.0 = s32[64,87] parameter(1) + values.1 = u32[64,87] parameter(2) + sort = (f32[64,87], s32[64,87], u32[64,87]) sort( + keys, values.0, values.1), + dimensions={1} + gte.0 = f32[64,87] get-tuple-element(sort), index=0 + gte.1 = u32[64,87] get-tuple-element(sort), index=2 + ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 0), + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 1)))); +} + +TEST_F(AlgebraicSimplifierTest, DontRemoveUnusedSortKey) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); @@ -3706,12 +3774,16 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + std::vector params; + for (int i = 0; i < 3; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + } builder.AddInstruction(HloInstruction::CreateDynamicSlice( shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + params, /*slice_sizes=*/{10, 100, 1000})); auto computation = m->AddEntryComputation(builder.Build()); @@ -3730,28 +3802,35 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 5, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( slice_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, slice_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, slice_shape, "to_update")), + slice, update_indices)); auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter()))); + GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. @@ -4412,9 +4491,10 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { HloInstruction* const update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); HloInstruction* const start_indices = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0({}))); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dslice_shape, operand, update, start_indices)); + dslice_shape, operand, update, + std::initializer_list({start_indices}))); const HloComputation* const computation = m->AddEntryComputation(builder.Build()); @@ -4467,14 +4547,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int32 start_row = (spec.lcd == 0) ? 0 : spec.s; int32 start_col = (spec.lcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_row))), builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, lhs, start_indices, slice_sizes)); int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; @@ -4507,7 +4590,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { } else { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), - m::Concatenate()))); + m::Constant(), m::Constant()))); } } @@ -4545,14 +4628,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int32 start_row = (spec.rcd == 0) ? 0 : spec.s; int32 start_col = (spec.rcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_row))), builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, rhs, start_indices, slice_sizes)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(spec.lcd); @@ -4577,7 +4663,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { } else { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), - m::Concatenate()))); + m::Constant(), m::Constant()))); } } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 47d2c7e35705698d49950c2fa042af1c6327d521..35a32f537b81fcd4bcbda4aaf8f56ea900559905 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -44,11 +45,24 @@ bool MatchesArCrsPattern(HloInstruction* instruction) { if (instruction->user_count() != 1) { return false; } - auto opcode = instruction->opcode(); - return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose || - opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert || - opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract || - opcode == HloOpcode::kMultiply; + switch (instruction->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + return true; + case HloOpcode::kConvert: + // Can be moved across if both input and output is either float or + // integer (e.g. S32<->U32 or F32<->BF16) + return ShapeUtil::ElementIsFloating(instruction->shape()) == + ShapeUtil::ElementIsFloating(instruction->operand(0)->shape()); + case HloOpcode::kAdd: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + // Only supported for floating point operands. + return ShapeUtil::ElementIsFloating(instruction->shape()); + default: + return false; + } }; auto computation_is_addition = [](HloComputation* c) { @@ -176,6 +190,15 @@ bool ArCrsCombiner::InstructionsComputeSameValue( if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { return false; } + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + if (i1->IsCrossModuleAllReduce()) { + return i1->Identical(*i2, + /*eq_operands=*/std::equal_to(), + eq_computations, + /*layout_sensitive=*/false); + } visited_pairs->emplace(min_uid, max_uid); for (int i = 0; i < operands1.size(); ++i) { auto operand1 = operands1[i]; @@ -201,9 +224,6 @@ bool ArCrsCombiner::InstructionsComputeSameValue( // InstructionsComputeSameValue earlier. auto eq_instructions = [](const HloInstruction* i1, const HloInstruction* i2) -> bool { return true; }; - auto eq_computations = [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }; return i1->Identical(*i2, eq_instructions, eq_computations, /*layout_sensitive=*/false); } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index caa57296f465698eb70d7cb8327d4678f394b323..b12b63b2dd779579307c1c4d4bf7860f5955af4d 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -360,6 +360,7 @@ HloModule foobar ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) %all-reduce.ar.1 = bf16[] all-reduce(%p), @@ -377,7 +378,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { sharding={maximal device=0} %all-reduce.ar.2 = bf16[] - all-reduce(%p), + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, to_apply=%sum.bf16, @@ -407,7 +408,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::AllReduce(op::Convert(op::Parameter())), - op::AllReduce(op::Convert(op::Parameter())))); + op::AllReduce(op::Convert(op::Constant())))); auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 2cf24a9dd5fa18abe9dde4eb49b03c6586bfef03..215e8ced4bb3f98a26ac4eb9912a7fd4d917852f 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -115,12 +115,10 @@ StatusOr Backend::BorrowStream(int device_ordinal) { StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); - if (0 == stream_pools_.count(executor)) { - stream_pools_.emplace(std::piecewise_construct, - std::forward_as_tuple(executor), - std::forward_as_tuple()); + if (!stream_pools_.contains(executor)) { + stream_pools_.emplace(executor, absl::make_unique()); } - return stream_pools_.at(executor).BorrowStream(executor); + return stream_pools_.at(executor)->BorrowStream(executor); } Backend::Backend(se::Platform* platform, Compiler* compiler, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 7ca993fb2656037951d98d9c4459a3c3e4c64c61..c35f033dc0180409ae3888c2050021da83f5c72a 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -175,7 +176,8 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map stream_pools_ GUARDED_BY(mu_); + absl::flat_hash_map> + stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 6f4c1104f363146cebea83873b936ea3f8292801..d07615b828990f80e2f905837c46f5f2e15d5a63 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -86,10 +86,9 @@ std::vector ColorInterferenceGraph( // first, but it would be good to investigate other ordering heuristics too. std::vector nodes(node_count); std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); + absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); const int64 kColorUnassigned = -1; std::vector assigned_colors(node_count, kColorUnassigned); @@ -138,8 +137,8 @@ Status GatherComputationsByAllocationType( worklist.pop_front(); const HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; + bool in_thread_local_set = thread_local_set.contains(computation); + bool in_global_set = global_set.contains(computation); // If the computation has already been added to the respective set, then // nothing to do. @@ -207,9 +206,9 @@ Status GatherComputationsByAllocationType( // Add the computations to the vectors in post order. for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { + if (thread_local_set.contains(computation)) { thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { + } else if (global_set.contains(computation)) { global_computations->push_back(computation); } // If the computation is not reachable from the entry computation, then it @@ -219,13 +218,6 @@ Status GatherComputationsByAllocationType( return Status::OK(); } -size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { - uint64 h = std::hash()(s.index()); - h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); - h = tensorflow::Hash64Combine(h, std::hash()(s.size())); - return h; -} - string BufferAllocation::Slice::ToString() const { return absl::StrCat("{index:", index(), ", offset:", offset_, ", size:", size_, "}"); @@ -240,7 +232,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); - CHECK(assigned_buffers_.count(&buffer) == 0) + CHECK(!assigned_buffers_.contains(&buffer)) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; CHECK_LE(offset, size_) << "LogicalBuffer " << buffer @@ -279,11 +271,12 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } - std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), - [](const BufferAllocationProto::Assigned& assign1, - const BufferAllocationProto::Assigned& assign2) { - return assign1.logical_buffer_id() < assign2.logical_buffer_id(); - }); + absl::c_sort(*proto.mutable_assigned(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return assign1.logical_buffer_id() < + assign2.logical_buffer_id(); + }); return proto; } @@ -315,10 +308,10 @@ string BufferAllocation::ToString() const { for (const auto& buffer_offset_size : assigned_buffers_) { sorted_buffers.push_back(buffer_offset_size.first); } - std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(sorted_buffers, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); StrAppend(&output, absl::StrFormat( @@ -346,7 +339,7 @@ const PointsToSet& BufferAssignment::GetPointsToSet( bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const { TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); - return allocation_index_for_buffer_.count(&buffer) > 0; + return allocation_index_for_buffer_.contains(&buffer); } const BufferAllocation& BufferAssignment::GetAssignedAllocation( @@ -401,7 +394,7 @@ bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction, const ShapeIndex& index) const { for (const LogicalBuffer* buffer : GetPointsToSet(instruction).element(index)) { - if (allocation_index_for_buffer_.count(buffer) > 0) { + if (allocation_index_for_buffer_.contains(buffer)) { return true; } } @@ -459,8 +452,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { - using SliceSet = - flat_hash_set; + using SliceSet = flat_hash_set; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -487,10 +479,9 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, // didn't return the empty set) for both HLOs, and the two resulting sets of // slices are disjoint. return !slices_a.empty() && !slices_b.empty() && - std::none_of(slices_a.begin(), slices_a.end(), - [&](const BufferAllocation::Slice& slice) { - return slices_b.count(slice) > 0; - }); + absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) { + return slices_b.contains(slice); + }); } StatusOr @@ -519,7 +510,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, void BufferAssignment::AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, int64 offset, int64 size) { - CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) + CHECK(!allocation_index_for_buffer_.contains(&buffer)) << "LogicalBuffer " << buffer << " already has an allocation."; CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) << "Non-reusable allocation already assigned a buffer: " @@ -960,35 +951,35 @@ Status BufferAssigner::AssignBuffersForComputation( // operands (assuming operands are the same/larger size) enabling the // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. - std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [has_sequential_order, &liveness, &post_order_position, assignment]( - const LogicalBuffer* a, const LogicalBuffer* b) { - // Primary sort is by decreasing buffer size. - const int64 a_size = assignment->buffer_size_(*a); - const int64 b_size = assignment->buffer_size_(*b); - if (a_size != b_size) { - return a_size > b_size; // use ">" for decreasing size. - } - // Otherwise live out buffers come before others, if the - // instructions are sequentially ordered. - if (has_sequential_order) { - const bool a_live_out = liveness.MaybeLiveOut(*a); - const bool b_live_out = liveness.MaybeLiveOut(*b); - if (a_live_out != b_live_out) { - return a_live_out; - } - } - // Final tiebreaker is in instruction post order. - return post_order_position.at(a->instruction()) < - post_order_position.at(b->instruction()); - }); + absl::c_sort(sorted_buffers, + [has_sequential_order, &liveness, &post_order_position, + assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + // Primary sort is by decreasing buffer size. + const int64 a_size = assignment->buffer_size_(*a); + const int64 b_size = assignment->buffer_size_(*b); + if (a_size != b_size) { + return a_size > b_size; // use ">" for decreasing size. + } + // Otherwise live out buffers come before others, if the + // instructions are sequentially ordered. + if (has_sequential_order) { + const bool a_live_out = liveness.MaybeLiveOut(*a); + const bool b_live_out = liveness.MaybeLiveOut(*b); + if (a_live_out != b_live_out) { + return a_live_out; + } + } + // Final tiebreaker is in instruction post order. + return post_order_position.at(a->instruction()) < + post_order_position.at(b->instruction()); + }); // BufferAllocations are necessarily created in decreasing size order. Keep // indices of previously created BufferAllocations in allocation_indices. std::vector allocation_indices; for (const LogicalBuffer* buffer : sorted_buffers) { VLOG(3) << "Assigning allocation to: " << *buffer; - if (colocated_buffers.count(buffer) > 0) { + if (colocated_buffers.contains(buffer)) { // Colocated buffers are currently assigned in an earlier pass. VLOG(3) << "Skipping colocated buffer: " << *buffer; continue; @@ -1056,7 +1047,7 @@ Status BufferAssigner::AssignBuffersForComputation( assignment->GetAllSlices(operand, /*index=*/{})) { BufferAllocation* allocation = assignment->GetMutableAllocation(operand_slice.index()); - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). However, @@ -1087,7 +1078,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Instructions are iterated in increasing buffer size, so any // previously create allocation must be large enough to hold this // instruction's output (with the exception of colocated buffers). - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). However, @@ -1313,10 +1304,10 @@ std::vector ComputePeakMemoryLogicalBuffers( live_buffers.end()); // Stabily sort the live buffers. - std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(live_buffers_vector, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); return live_buffers_vector; } @@ -1376,7 +1367,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( std::vector overlap_set_indices; for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { - if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + if ((*colocated_buffer_sets)[index].contains(buffer)) { VLOG(5) << "Found overlap with existing set on buffer " << buffer->ToString() << "\n" << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], @@ -1539,15 +1530,16 @@ void BufferAssigner::BuildColocatedBufferSets( VLOG(4) << "Input/Output Alias Config: "; VLOG(4) << module->input_output_alias_config(); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { std::vector colocated_set; AddBufferToColocatedSet(module->entry_computation()->root_instruction(), output_index, points_to_analysis, &colocated_set); AddBufferToColocatedSet( - module->entry_computation()->parameter_instruction(param_number), - param_index, points_to_analysis, &colocated_set); + module->entry_computation()->parameter_instruction( + alias.parameter_number), + alias.parameter_index, points_to_analysis, &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 0a9fdede803e84ca42472259084615c031b206eb..4baab9b6ad71293d48d5ed70c2922fdf40ef119a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -186,9 +186,10 @@ class BufferAllocation { end > other.offset_; } - struct Hasher { - size_t operator()(Slice s) const; - }; + template + friend H AbslHashValue(H h, const Slice& s) { + return H::combine(std::move(h), s.index(), s.offset(), s.size()); + } string ToString() const; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8f482e6ba8c3e71c9980be5e6947ea61f3b4ef29..1b4e93a2f303e5aad3e4081f36e2417277f62c71 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -309,7 +310,7 @@ class BufferAssignmentTest : public HloTestBase { static bool BuffersDistinct(const std::vector& a, const std::vector& b, const BufferAssignment& assignment) { - std::set a_slices; + absl::flat_hash_set a_slices; for (const HloInstruction* instruction : a) { if (assignment.HasTopLevelAllocation(instruction)) { a_slices.insert( @@ -319,8 +320,8 @@ static bool BuffersDistinct(const std::vector& a, for (const HloInstruction* instruction : b) { if (assignment.HasTopLevelAllocation(instruction)) { - if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) - .ConsumeValueOrDie())) { + if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction) + .ConsumeValueOrDie())) { return false; } } @@ -2485,9 +2486,9 @@ while_body { get-tuple-element.3 = s32[] get-tuple-element(state), index=0 constant.2 = s32[] constant(128) add.5 = s32[] add(get-tuple-element.3, constant.2) - constant.3 = s32[3]{0} constant({0, 0, 0}) - dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3) - dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) } diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 79241342005d3dc59051ea638798062ff56f11a9..5de724f8924b78008ba4c56603b61bf93fbc5e7c 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -384,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // Verify visitation order of some computations in the graph. auto index_of = [&visited](const HloComputation* comp) { - auto it = std::find(visited.begin(), visited.end(), comp); + auto it = absl::c_find(visited, comp); EXPECT_NE(it, visited.end()); return std::distance(visited.begin(), it); }; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7..b517495f2ea0c75679685c67f757ff586f8c79e3 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -72,7 +72,7 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { } Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; @@ -94,7 +94,7 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { } Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index 52037bf9b52556c6aa2e66dd3209e25cf085cfe3..89e17eba36f23077ce4cf0704e7455b76bee68d1 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status.h" @@ -83,7 +84,8 @@ class ChannelTracker { // Mapping from ChannelHandle value to the corresponding registered // Channel object. - std::map opaque_to_channel_ GUARDED_BY(channel_mutex_); + absl::flat_hash_map opaque_to_channel_ + GUARDED_BY(channel_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker); }; diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8f08c244908efb823b3870c19bdc3491fa87d44f..653f4555a77cc82e91fb1cd26206b93826375732 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -98,10 +98,17 @@ Compiler::GetPlatformCompilers() { auto* factories = GetPlatformCompilerFactories(); auto it = factories->find(platform->id()); if (it == factories->end()) { + string hint; + if (platform->Name() == "Host") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_cpu_jit)"; + } else if (platform->Name() == "CUDA") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_gpu_jit)"; + } + return NotFound( "could not find registered compiler for platform %s -- check " - "target linkage", - platform->Name()); + "target linkage%s", + platform->Name(), hint); } // And then we invoke the factory, placing the result into the mapping. diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index efc893818d03a20d6bd65b7dc1da72ea5da5ceb0..92d1ca4ba5da802a5f1c544017ac52dda38e9b1d 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -42,8 +42,8 @@ void ComputationLayout::SetToDefaultLayout() { } bool ComputationLayout::LayoutIsSet() const { - return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(), - [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && + return absl::c_all_of(parameter_layouts_, + [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && result_layout_.LayoutIsSet(); } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index e0165d557deb390636b42df12f099792042f1b65..5e26a63cebfa9b2e50f4b13335c10c246999d4df 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -349,11 +349,12 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { ShapeTree param_indices_to_copy(param->shape()); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { - if (param_number == param->parameter_number()) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + if (alias.parameter_number == param->parameter_number()) { param_has_alias = true; - *(param_indices_to_copy.mutable_element(param_index)) = true; + *(param_indices_to_copy.mutable_element(alias.parameter_index)) = + true; *(output_indices_to_copy.mutable_element(output_index)) = true; } }); @@ -395,13 +396,14 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { // Add control dependencies between the input/output copies. TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& input_index) -> Status { - if (!copied_parameters[param_number]) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) -> Status { + if (!copied_parameters[alias.parameter_number]) { return Status::OK(); } HloInstruction* from = - copied_parameters[param_number]->element(input_index); + copied_parameters[alias.parameter_number]->element( + alias.parameter_index); HloInstruction* to = output_copy_tree.element(output_index); TF_RET_CHECK(from != nullptr); @@ -539,10 +541,9 @@ class CopyRemover { } std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); + absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); // Create a list containing all of the values in the buffer. AddValueList(values, &value_to_node); @@ -842,12 +843,11 @@ class CopyRemover { copy_value_node->next->prev = operand_node; // Patch up uses. Remove use of copy from operand_node uses. - auto it = - std::find_if(operand_node->uses.begin(), operand_node->uses.end(), - [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); + auto it = absl::c_find_if( + operand_node->uses, [copy_value_node](const HloUse* use) { + return use->instruction == + copy_value_node->value->defining_instruction(); + }); CHECK(it != operand_node->uses.end()); operand_node->uses.erase(it); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index e4e9d7ba05c115be9dd0eb53ebd7de208d514efb..4d4074943e3bf9f6f2a37abc63f037c2dab06e0f 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1376,9 +1376,11 @@ TEST_F(CopyInsertionTest, CrossingParameters) { builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 4); @@ -1409,9 +1411,11 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1475,7 +1479,8 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1516,7 +1521,8 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1557,7 +1563,8 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index f49b5110be5c4bab63b423e5ed2e67bc1828f6e3..a197bdddc88dafe9b1a5769da7bc1bdc54f84dc2 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -1,6 +1,14 @@ # Description: # LLVM-based CPU backend for XLA. +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") +load( + "//third_party/mkl:build_defs.bzl", + "mkl_deps", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load(":build_defs.bzl", "runtime_copts") + licenses(["notice"]) # Apache 2.0 package( @@ -14,15 +22,6 @@ package_group( ], ) -load(":build_defs.bzl", "runtime_copts") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") -load( - "//third_party/mkl:build_defs.bzl", - "mkl_deps", -) - # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", @@ -114,6 +113,7 @@ cc_library( "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -241,6 +241,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/host:host_stream", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -364,15 +365,33 @@ cc_library( ], ) +cc_library( + name = "tiled_dot_emitter", + srcs = ["tiled_dot_emitter.cc"], + hdrs = ["tiled_dot_emitter.h"], + deps = [ + ":vector_support_library", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], - hdrs = ["dot_op_emitter.h"], + hdrs = [ + "dot_op_emitter.h", + ], deps = [ ":cpu_options", ":cpu_runtime", ":ir_emission_utils", ":target_machine_features", + ":tiled_dot_emitter", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -631,6 +650,7 @@ cc_library( deps = [ ":runtime_matvec", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) @@ -1008,7 +1028,6 @@ tf_cc_test( size = "small", srcs = ["cpu_eigen_tensor_alignment_test.cc"], deps = [ - ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features_fake", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index dbab839aee95b249bd993e7a0a63fd3af38f8537..a6d92ce10ee311704eb4e40cee1d14fa6fdb1e57 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -69,6 +69,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -244,6 +245,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( @@ -302,7 +304,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) + return DotImplementationCanHandleTranspose(dot, + *target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index 8727c72b6e42517b1859e98ecadb41bbceed761c..485769a373acf5ae70c471b1a5dfcfb20ff772ef 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -28,37 +27,6 @@ namespace { class CpuEigenTensorAlignmentTest : public ::testing::Test {}; -TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { - string hlo_string = R"( -HloModule DotOperation - -ENTRY DotOperation { - arg0 = f32[5,256] parameter(0) - arg1 = f32[256,1024] parameter(1) - ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_string)); - - HloInstruction* dot = module->entry_computation()->root_instruction(); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( - [](int64 size) { return 1; }); - - EXPECT_FALSE( - PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); - - TargetMachineFeaturesWithFakeAlignmentLogic - target_machine_with_full_alignment([](int64 size) { - return TargetMachineFeatures::kEigenExpectedTensorAlignment; - }); - - EXPECT_TRUE(PotentiallyImplementedAsEigenDot( - *dot, target_machine_with_full_alignment)); -} - TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { string hlo_string = R"( HloModule ConvOperation diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 527df0bd1c23bba74f32226e5622fed32f7dcf84..46c1d4c38e01b6d3b79699fc0cdc73868c974353 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -332,7 +332,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {8}); - Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); + Shape starts_shape = ShapeUtil::MakeShape(F32, {}); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -340,13 +340,15 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, starts_shape, "starts")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); HloInstruction* broadcast2 = builder.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); HloInstruction* reshape3 = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, broadcast2)); HloInstruction* dynamic_slice4 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, reshape3, param1, {4, 4})); + dynamic_slice_shape, reshape3, {param1, param2}, {4, 4})); builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); @@ -356,7 +358,8 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { RunFusionAndCheckOpcodesWereFused( module.get(), {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, - HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); + HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, Broadcast_Negate) { @@ -381,14 +384,16 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {4}); - Shape slice_shape = ShapeUtil::MakeShape(F32, {1}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {}); Shape result_shape = ShapeUtil::MakeShape(F32, {2}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, slice_shape, "starts")); - HloInstruction* dynamic_slice2 = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(result_shape, param0, param1, {2})); + HloInstruction* dynamic_slice2 = + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + result_shape, param0, + std::initializer_list({param1}), {2})); builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); @@ -548,28 +553,36 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 1 + i, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 5 + i, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, full_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, full_shape, "to_update")), + slice, update_indices)); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( - module.get(), {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, - HloOpcode::kParameter, HloOpcode::kParameter, - HloOpcode::kParameter, HloOpcode::kParameter}); + module.get(), + {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 1cb154e9ca3bbfd9742df96316272e0fc653da36..95b8025f873c56bea063ff258d4abd6614257d85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -46,8 +46,7 @@ static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { for (auto* user : instruction->users()) { optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); if (!operand_idx || user->operand(*operand_idx) != instruction || - std::count(user->operands().begin(), user->operands().end(), - instruction) != 1) { + absl::c_count(user->operands(), instruction) != 1) { return false; } } @@ -94,60 +93,38 @@ static Shape ColMajorShape(const Shape& old_shape) { return new_shape; } +static bool OperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& instr, + const TargetMachineFeatures& target_machine_features) { + if (instr.opcode() == HloOpcode::kConvolution) { + return PotentiallyImplementedAsEigenConvolution(instr, + target_machine_features); + } else if (instr.opcode() == HloOpcode::kDot) { + return DotOperandsAndResultMustHaveRowMajorLayout(instr, + target_machine_features); + } + return false; +} + Status CpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction, - target_machine_features_)) { - const HloInstruction* convolution = instruction; - const HloInstruction* lhs_instruction = convolution->operand(0); - const HloInstruction* rhs_instruction = convolution->operand(1); - - // In order to implement `convolution` with Eigen convolution, the layouts - // of the input, filter, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(convolution->shape())); - Shape input_shape(RowMajorShape(lhs_instruction->shape())); - Shape filter_shape(RowMajorShape(rhs_instruction->shape())); - - // Set layouts of the instructions' shapes. - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, convolution, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, convolution, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, convolution)); + if (OperandsAndResultMustHaveRowMajorLayout(*instruction, + target_machine_features_)) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + RowMajorShape(instruction->shape()), instruction)); + for (int i = 0; i < instruction->operand_count(); i++) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + RowMajorShape(instruction->operand(i)->shape()), instruction, i)); + } } else if (optional op_idx = ShouldMakeOperandColumnMajor(&cache, *instruction)) { const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_)) { - const HloInstruction* dot = instruction; - // In order to implement `dot` with Eigen dot, the layouts of the lhs, - // rhs, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(dot->shape())); - - const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - - // Set layouts of the instructions' shapes. - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); } else { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 92debb83e33b1400a59e5eef0f90971392ab7b22..ff654c83d61e7cc09ac7839feccaf2bc9cb3c63c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -23,8 +23,8 @@ namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; -const char* const kXlaEnableExperimentalLlvmIrGemm = - "xla_enable_experimental_llvm_ir_gemm"; +const char* const kXlaForceEnableExperimentalLlvmIrGemm = + "xla_force_enable_experimental_llvm_ir_gemm"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -57,10 +57,10 @@ absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config) { return absl::nullopt; } -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); - return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; + return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } static absl::string_view RemoveSuffix(absl::string_view str, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 47c7eb13b6e4cc05a23f82b8d2a25249f4b82ac0..99e6702d14aed8ffb148adec2bdd02dbc7c3c7e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,7 +26,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index 3ae64142cd7e32d3aa8d50870efaf94698c06440..c3c6847b7b77e2fb0470630815de9f5d7a6c5b9c 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -77,17 +77,16 @@ StatusOr Disassembler::DisassembleObjectFile( } // Sort the symbols in increasing address order. - std::sort( - symbols.begin(), symbols.end(), - [](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) { - // getAddress returns a Expected object. Assert there is no error - // before extracting the address. - llvm::Expected a_address_or_error = a.getAddress(); - CHECK(a_address_or_error); - llvm::Expected b_address_or_error = b.getAddress(); - CHECK(b_address_or_error); - return a_address_or_error.get() < b_address_or_error.get(); - }); + absl::c_sort(symbols, [](const llvm::object::SymbolRef& a, + const llvm::object::SymbolRef& b) { + // getAddress returns a Expected object. Assert there is no error + // before extracting the address. + llvm::Expected a_address_or_error = a.getAddress(); + CHECK(a_address_or_error); + llvm::Expected b_address_or_error = b.getAddress(); + CHECK(b_address_or_error); + return a_address_or_error.get() < b_address_or_error.get(); + }); // Construct ArrayRef pointing to section contents. llvm::StringRef section_content_string; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 1525a33af7a490f99762733a677d7913c480200d..b018e0cd462ecc8d7924ecc5e1ecc3d27c6fe2c6 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" @@ -41,931 +42,163 @@ namespace xla { using llvm_ir::SetToFirstInsertPoint; namespace cpu { - namespace { -// Provides tiled access to an in-memory rank 2 array. -class MemoryTile { - public: - // Constructs a MemoryTile that can operate on tiles consisting of - // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at - // `major_dim_offset` in the major dimension. The tile size along the minor - // dimension is the vector size, and that is implicitly determined by `vsl`. - MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, - llvm::Value* matrix, int64 matrix_size_along_minor_dim, - llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl), b_(b) { - pointers_.reserve(tile_size_along_major_dim); - for (int64 i = 0; i < tile_size_along_major_dim; i++) { - llvm::Value* total_offset = - b->CreateMul(b->getInt64(matrix_size_along_minor_dim), - b->CreateAdd(b->getInt64(i), major_dim_offset)); - pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); - } - } - - // Load a tile consisting of `tile_size_along_major_dim` vectors from position - // {major: `major_dim_offset`, minor: `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - std::vector LoadTile(llvm::Value* minor_dim_offset) const { - std::vector result; - result.reserve(pointers_.size()); - for (const auto& pointer : pointers_) { - result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); - } - return result; - } - - // Stores `tile` to position {major: `major_dim_offset`, minor: - // `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - void StoreTile(absl::Span tile, - llvm::Value* minor_dim_offset) const { - CHECK_EQ(tile.size(), pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); - } - } +// Returns true if we should call into multi-threaded Eigen routines. +bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) { + return config.debug_options().xla_cpu_multi_thread_eigen(); +} - // Loads a tile of size [`tile_size_along_major_dim`, - // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, - // minor: `minor_dim_offset`} and then broadcasts each element into a vector - // of size vsl_.vector_size(). The (i,j)'th element of the return value is - // the (i,j)'th element in the tile broadcasted into an LLVM vector. - // - // Note: `major_dim_offset` is a parameter to the constructor. - std::vector> LoadBroadcastTile( - llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { - std::vector> result; - result.resize(pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - for (int64 j = 0; j < tile_size_along_middle_dim; j++) { - result[i].push_back(vsl_->LoadBroadcast( - pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); - } - } - return result; +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); } - - private: - VectorSupportLibrary* vsl_; - llvm::IRBuilder<>* b_; - std::vector pointers_; }; -// The base class for the classes representing the GEMV emitter configurations. -// -// The IR emitted (modulo the LLVM values representing the input and output -// buffers) by the row major and column major GEMV emitters should be a function -// of their configuration. This is important because their configuration is -// used as a key to cache the generated IR. -class GemvConfig { - public: - // Mixin for convenience. - template - struct User { - public: - PrimitiveType scalar_type() const { - return derived().config().scalar_type(); - } - int64 tile_rows() const { return derived().config().tile_rows(); } - int64 tile_cols() const { return derived().config().tile_cols(); } - int64 m() const { return derived().config().m(); } - int64 k() const { return derived().config().k(); } - int64 has_addend() const { return derived().config().has_addend(); } - - private: - const T& derived() const { return *static_cast(this); } - }; - - PrimitiveType scalar_type() const { return scalar_type_; } - int64 tile_rows() const { return tile_rows_; } - int64 tile_cols() const { return tile_cols_; } - int64 m() const { return m_; } - int64 k() const { return k_; } - bool has_addend() const { return has_addend_; } - - string GetCacheKey() const { - return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", - tile_rows(), "_", tile_cols(), "_", m(), "_", k(), - has_addend() ? "_with_addend" : ""); - } - - protected: - explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, bool has_addend) - : name_(std::move(name)), - scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), - has_addend_(has_addend) {} - - private: - string name_; - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; - bool has_addend_; +// Dictates how a dot operation is implemented. +enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trival dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implemetns a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts. + kEigen, }; -// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ +--+--+--+--+ -// |M00|M10|M20|M30| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M03|M13|M23|M33| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// -// (Legend: rows are horizontal and columns are vertical; and each column is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is from the column major left matrix. -// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] -// vector loaded from the RHS vector. -// -// As we iterate through the column dimension, we compute the change to the -// result vector by an elementwise multiplication between the two tiles above -// followed by a reduction along the major dimension: -// -// +-----------------------------------+ -// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | -// +-----------------------------------+ -// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | -// Result[R:R+4] += +-----------------------------------+ -// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | -// +-----------------------------------+ -// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | -// +-----------------------------------+ -// -// Where R is the starting row for the tile. -// -// We have an inner epilogue loop to deal with the "C" submatrix and an outer -// epilogue loop to deal with the B,D submarix. -// -// TODO(sanjoy): We should investigate if using gather loads and scatter stores -// can be used here have the same inner loop for both column-major and row-major -// matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter - : public GemvConfig::User { - public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"col_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, - llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { - CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } - - private: - void EmitOuterLoopBody(llvm::Value* column, int64 column_count, - bool is_first_column); - - MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m(), - /*major_dim_offset=*/column_start, - /*tile_size_along_major_dim=*/column_count); - } +// Returns the implementation strategy for a dot with the configuration +// `dot_info`. +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); - // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous - // sequence of `count` values, each one broadcasted to the vector width. - std::vector LoadRhsTile(llvm::Value* offset, int64 count) { - llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); - std::vector result; - result.reserve(count); - for (int64 i = 0; i < count; i++) { - result.push_back(vsl_.LoadBroadcast(base_pointer, i)); - } - return result; - } - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, - const std::vector& rhs_tile, - int64 columns, bool is_first_column); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, - bool is_first_tiled_column); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; -}; - -void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( - llvm::Value* column, int64 column_count, bool is_first_column) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, - /*column_count=*/column_count); - - std::vector rhs_tile = - LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, - /*columns=*/column_count, is_first_column); - EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); -} - -void ColumnMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k() % tile_cols(); - int64 column_limit = k() - column_remainder; - - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); - - if (column_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, - column_limit == 0); - } -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, - int64 columns, bool is_first_column) { - int64 row_limit = m() - (m() % tile_rows()); - - ksl_.For( - "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows(), [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m() - (m() % tile_rows()); - if (row_start == m()) { - return; - } - - llvm::Value* columns_llvm = b_->getInt64(columns); - - // for (col = current_tile_col; col < (columns + current_tile_col); col++) - // for (row = row_start, row < m_; row++) { - // result[row] += lhs[row, col] * rhs[col] - // // Also take into account that if col is 0 then result[row] is not - // // initialized. - // } - - ksl_.For( - "dot.inner.epilg.outer", /*start=*/current_tile_col, - /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), - /*step=*/1, /*peel_first_iteration=*/false, - [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { - llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); - llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), - /*step=*/1, [&](llvm::Value* scalar_row) { - llvm::Value* product = vsl_.Mul( - vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); - llvm::Value* setting_result_first_time = b_->CreateAnd( - is_first_scalar_col, b_->getInt1(is_first_tiled_column)); - ksl_.If( - setting_result_first_time, - /*true_block_generator=*/ - [&]() { - if (addend_) { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), - product), - result_, scalar_row); - } else { - vsl_.StoreScalar(product, result_, scalar_row); - } - }, - /*false_block_generator=*/ - [&]() { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), - result_, scalar_row); - }); - }); - }); -} - -// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ -// |M00|M10|M20|M30| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| -// +---+---+---+---+ -// |M03|M13|M23|M33| -// +---+---+---+---+ -// -// (Legend: rows are horizontal and columns are vertical; and each row is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is loaded from the row major left matrix. -// b. The right vector is loaded from the RHS vector. -// -// We keep 4 vector accumulators accumulating the following four vector -// expressions as we iterate over the row dimension: -// -// +------+------+------+------+ -// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) -// +------+------+------+------+ -// -// In the end we do a horizontal reduction over these 4 vector accumulators to -// get 4 values in the result vector. -// -// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer -// epilogue loop to deal with the C,D submatrix. -class RowMajorMatrixVectorProductEmitter - : public GemvConfig::User { +// Helper class for emitting LLVM IR to perform the dot operation. +class DotOpEmitter { public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"row_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { - CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } + explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); + + // Emits the IR to perform the dot operation. + Status Emit(); private: - MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k(), - /*major_dim_offset=*/row_start, - /*tile_size_along_major_dim=*/row_count); - } - - void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; -}; + // Emits instructions to perform a scalar dot product (a multiply of the + // LHS and RHS) and store the results in the target. + Status EmitScalarDot(); -void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, - int64 row_count) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, - /*row_count=*/row_count); - std::vector vector_accumulators; - std::vector scalar_accumulators; - for (int i = 0; i < row_count; i++) { - vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); - scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); - } - EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, - &vector_accumulators); - EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, - &scalar_accumulators); - - std::vector accumulator_values; - std::transform( - vector_accumulators.begin(), vector_accumulators.end(), - std::back_inserter(accumulator_values), - [](const VectorVariable& vector_var) { return vector_var.Get(); }); - - std::vector horizontal_sums; - if (row_count == vsl_.vector_size()) { - if (addend_) { - horizontal_sums = vsl_.ComputeHorizontalSums( - std::move(accumulator_values), vsl_.LoadVector(addend_, row)); - } else { - horizontal_sums = - vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - } else { - horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - - for (int i = 0; i < row_count; i++) { - llvm::Value* result_value = - vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); - llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); - if (addend_ && row_count != vsl_.vector_size()) { - result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); - } - vsl_.StoreScalar(result_value, result_, offset); - } -} + // Emits a call to the CPU runtime to perform the matrix multiply. + Status EmitCallToRuntime(); -void RowMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m() % tile_rows(); - int64 row_limit = m() - row_remainder; + // Represents the dimensions of a matrix-matrix multiply operation. + struct MatMultDims { + // The number of rows in the LHS. + int64 m; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); - - if (row_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); - } -} - -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators) { - int64 column_limit = k() - (k() % tile_cols()); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); -} - -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators) { - int64 column_start = k() - (k() % tile_cols()); - if (column_start == k()) { - return; - } - - for (int r = 0; r < rows; r++) { - llvm::Value* total_offset = b_->CreateMul( - b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); - } -} + // The number of columns in the LHS, which is also must be equal to the + // number of rows in the RHS. + int64 k; -// This class implements a tiled matrix multiplication algorithm, intended for -// multiplying small matrices that don't need cache tiling. -// -// In the future this can be used as the innermost GEBP loop in a GEMM kernel as -// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of -// high-performance matrix multiplication." ACM Transactions on Mathematical -// Software (TOMS) 34.3 (2008): 12.". -// -// This only supports canonical dot operations (i.e. where the lhs contraction -// dimension is 1 and the rhs contraction dimension is 0) over row major -// matrices. -class TiledSmallGemmEmitter { - public: - // Describe the dimensions of the kernel. - class Dimensions { - public: - explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + // The number of columns on the RHS. + int64 n; - int64 m() const { return m_; } - int64 k() const { return k_; } - int64 n() const { return n_; } + // True if the LHS matrix is column major. + bool lhs_column_major; - string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; - private: - const int64 m_; - const int64 k_; - const int64 n_; - }; + // True if the RHS matrix is column major. + bool rhs_column_major; - // Represents the configuration of the emitter. The LLVM IR emitted by the - // emitter, modulo the LLVM values holding the input and output buffers, must - // be a function of the instance of `Config` passed to it. - // - // `dims` holds the matrix multiplication dimensions. - // - // `max_vectorization_width` is the maximum vector width (i.e. the width of - // the largest vector register we will use). This can be larger than the - // largest vector register supported by the machine -- LLVM will legalize - // these large vector widths into legally sized vectors. - // - // `max_vector_count` is the maximum number of vectors of size - // `max_vectorization_width` that we will attempt to process at once. - // - // `min_vectorization_width` is the smallest vector width the emitter will use - // -- below that it will devolve to using a scalar loop. - // - // The innermost reduction loop executes the matrix multiply in tiles of size - // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, - // ] in the RHS. - class Config { - public: - explicit Config(PrimitiveType scalar_type, Dimensions dims, - int64 max_vectorization_width, int64 max_vector_count, - int64 min_vectorization_width, int64 tile_size_m, - int64 tile_size_k) - : scalar_type_(scalar_type), - dims_(dims), - max_vectorization_width_(max_vectorization_width), - max_vector_count_(max_vector_count), - min_vectorization_width_(min_vectorization_width), - tile_size_m_(tile_size_m), - tile_size_k_(tile_size_k) {} - - string GetCacheKey() const { - return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", - dims().ToString(), "_", max_vectorization_width(), - "_", min_vectorization_width(), "_", tile_size_m(), - "_", tile_size_k()); - } + // True if the RHS contraction dimension is not 0. + bool rhs_non_canonical; - PrimitiveType scalar_type() const { return scalar_type_; } - Dimensions dims() const { return dims_; } - int64 max_vectorization_width() const { return max_vectorization_width_; } - int64 max_vector_count() const { return max_vector_count_; } - int64 min_vectorization_width() const { return min_vectorization_width_; } - - int64 tile_size_m() const { return tile_size_m_; } - int64 tile_size_k() const { return tile_size_k_; } - - private: - PrimitiveType scalar_type_; - Dimensions dims_; - int64 max_vectorization_width_; - int64 max_vector_count_; - int64 min_vectorization_width_; - int64 tile_size_m_; - int64 tile_size_k_; + // True if the result matrix is column major. + bool target_column_major; }; - // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies - // `lhs` with `rhs` and stores the result in `result`. - explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b) - : lhs_(lhs), - rhs_(rhs), - result_(result), - config_(config), - b_(b), - ksl_(b_) { - CHECK(max_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(max_vectorization_width()))); - CHECK_GT(max_vector_count(), 0); - CHECK(min_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(min_vectorization_width()))); - CHECK_GE(max_vectorization_width(), min_vectorization_width()); - CHECK_GT(tile_size_k(), 0); - } - - void Emit(); - - private: - // The HandleResiduesOnX helpers split the iteration space for dimension X - // into a multiple of the tile size on dimension X and an epilogue. These - // helpers ultimately call into `EmitTiledGemm` for emitting the - // tiled GEMM kernel. - - void HandleResiduesOnN(); - void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, - llvm::Value* n_end); - void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end); - - // This emits a tiled GEMM kernel. For a detailed description see the comment - // on the implementation. - void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, - llvm::Value* m_end); - - llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } - - Config config() const { return config_; } - Dimensions dims() const { return config().dims(); } - - int64 max_vectorization_width() const { - return config().max_vectorization_width(); - } - int64 max_vector_count() const { return config().max_vector_count(); } - int64 min_vectorization_width() const { - return config().min_vectorization_width(); - } - int64 tile_size_m() const { return config().tile_size_m(); } - int64 tile_size_k() const { return config().tile_size_k(); } - PrimitiveType scalar_type() const { return config().scalar_type(); } - - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* result_; - Config config_; - + // Get the MatMultDims instance for the dot product this DotOpEmitter + // represents. Precondition: the dot is of rank 2 (and thus its operands are + // of rank 2 as well). + MatMultDims GetMatMultDims() const; + + // Lowers the dot operation as a tiled Matrix*Vector loop. + void EmitTiledLlvmIrGemv(); + + // Lowers the dot operation as a tiled Matrix*Matrix loop. + void EmitTiledLlvmIrGemm(); + + // Lowers the dot operation as a naive nested loop that computes the result + // one element at a time. + void EmitNaiveLlvmIrGemm(); + + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector + // registers. + int64 GetGemvTilingFactor() const { + const int64 kDefaultTilingFactor = 8; + return options::LlvmIrGemvTilingFactor(hlo_module_config_) + .value_or(kDefaultTilingFactor); + } + + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. + const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + + DotInfo dot_info_; + string dot_hlo_name_; + const llvm_ir::IrArray& target_array_; + const llvm_ir::IrArray& lhs_array_; + const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; + llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; + const HloModuleConfig& hlo_module_config_; + const TargetMachineFeatures& target_machine_features_; }; - -void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } - -void TiledSmallGemmEmitter::HandleResiduesOnN() { - // We can only iterate the `n` dimension for an extent that is divisible by - // the vectorization width. So we emit an outer loop that first processes the - // largest extent in `n` that is divisible by max_vectorization_width, then - // the largest remaining extent that is divisible by max_vectorization_width / - // 2 etc. - - int64 current_vectorization_width = - max_vector_count() * max_vectorization_width(); - int64 current_vector_count = max_vector_count(); - - int64 n_start = 0; - while (n_start != dims().n() && - current_vectorization_width >= min_vectorization_width()) { - int64 n_end = dims().n() - (dims().n() % current_vectorization_width); - if (n_start != n_end) { - VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, - "gemm"); - HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); - n_start = n_end; - } - if (current_vector_count == 1) { - current_vectorization_width /= 2; - } else { - current_vector_count--; - current_vectorization_width = - current_vector_count * max_vectorization_width(); - } - } - - if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); - ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { - llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); - HandleResiduesOnK(&vsl, n_i, n_i_next); - }); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { - int64 k_start = 0; - int64 k_end = dims().k() - (dims().k() % tile_size_k()); - if (k_end != k_start) { - HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), - n_start, n_end); - k_start = k_end; - } - - if (k_start != dims().k()) { - HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), - GetInt64(dims().k()), n_start, n_end); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnM( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { - const int64 m_end = dims().m() - dims().m() % tile_size_m(); - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), - GetInt64(0), GetInt64(m_end)); - - if (m_end != dims().m()) { - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, - dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); - } -} - -// The loop structure is: -// -// Iterate over dimension M as m: -// Iterate over dimension N as n: -// Iterate over dimension K as k: -// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) -// -// I.e. a just a tiled version of a "naive" GEMM. -// -// The tiling scheme is as follows: -// -// Let the LHS be: -// -// +----+----+----+ -// | a0 | b0 | c0 | . -// +----+----+----+ . -// | a1 | b1 | c1 | . -// +----+----+----+ -// .. .. -// -// and the RHS be: -// -// +----+----+----+----+ -// | p0 | p1 | p2 | p3 | . -// +----+----+----+----+ . -// | q0 | q1 | q2 | q3 | . -// +----+----+----+----+ -// | r0 | r1 | r2 | r3 | . -// +----+----+----+----+ . -// ...... ...... -// -// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted -// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] -// matrix that we can increment the result matrix by. -// -// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank -// 3 array, L, of dimension [2,3,4]: -// -// L[0,_,_] * L[1,_,_] -// * -// +----+----+----+----+ * +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | -// +----+----+----+----+ * +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | -// +----+----+----+----+ * +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | -// +----+----+----+----+ * +----+----+----+----+ -// -// -// Then we FMA L[0,_,_] with the RHS to get the first row of the result and -// L[1,_,_] with the RHS to get the second row of the result. For example, -// L[0,_,_] is computed as: -// -// +----+----+----+----+ +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | -// +----+----+----+----+ +----+----+----+----+ -// -// to get: -// -// +-------------------+-------------------+-------------------+--------- -// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... -// +-------------------+-------------------+-------------------+--------- -void TiledSmallGemmEmitter::EmitTiledGemm( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.For( - "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { - MemoryTile result_memory_tile( - vsl, b_, /*matrix=*/result_, - /*matrix_size_along_minor_dim=*/dims().n(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/dims().k(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - ksl_.For( - "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - TileVariable result_tile_var(vsl, - result_memory_tile.LoadTile(n_i)); - ksl_.For( - "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, - tile_size_k); - std::vector> lhs_tile = - lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); - std::vector rhs_tile = - rhs_memory_tile.LoadTile(n_i); - std::vector result_tile = - result_tile_var.Get(); - for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { - for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { - result_tile[r_m_i] = - vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], - result_tile[r_m_i]); - } - } - result_tile_var.Set(result_tile); - }); - - result_memory_tile.StoreTile(result_tile_var.Get(), n_i); - }); - }); -} - } // namespace -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, +DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, @@ -974,7 +207,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) - : dot_(dot), + : dot_info_(std::move(dot_info)), + dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -984,58 +218,9 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) { - PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, - addend_array, executable_run_options_value, b, - hlo_module_config, target_machine_features); - return dot_emitter.Emit(); -} - -bool DotOpEmitter::EmitSmallGemmIfProfitable( - const DotOpEmitter::MatMultDims& mat_mult_dims) { - if (ShouldUseMultiThreadedEigen()) { - return false; - } - - if (!EnableExperimentalLlvmIrGemm()) { - // TODO(sanjoy): We should make these numbers micro-arch specific. - bool small_gemm = mat_mult_dims.k <= 128 && - ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) || - (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32)); - if (!small_gemm) { - return false; - } - } - - if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); - - switch (primitive_type) { - default: - return false; - - case F32: - case F64: - case S32: - case S64: - break; - } - - if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && - mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { - return false; - } +void DotOpEmitter::EmitTiledLlvmIrGemm() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); + MatMultDims mat_mult_dims = GetMatMultDims(); llvm::Value* lhs = lhs_array_.GetBasePointer(); llvm::Value* rhs = rhs_array_.GetBasePointer(); @@ -1050,9 +235,8 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( } int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - b_->CreateMemSet( - target, b_->getInt8(0), size_bytes, - target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/1); int64 max_target_vector_width = target_machine_features_.vector_register_num_elements( @@ -1062,47 +246,28 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - TiledSmallGemmEmitter::Config config( - /*scalar_type=*/primitive_type, - TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, - /*max_vectorization_width=*/max_target_vector_width, - /*max_vector_count=*/tile_size_n_in_vector_width, - /*min_vectorization_width=*/std::min(4, max_target_vector_width), - /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); - - VLOG(2) << "Emitting GEMM kernel in LLVM IR with config " - << config.GetCacheKey(); - const bool enable_fast_math = hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); - KernelSupportLibrary::EmitAndCallOutlinedKernel( + EmitSmallGemm( + /*scalar_type=*/primitive_type, + /*m=*/m, /*k=*/k, /*n=*/n, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs, + /*rhs=*/rhs, /*result=*/target, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, - rhs, target, - [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, - /*rhs=*/rhs, - /*result=*/target, b_); - small_gemm_emitter.Emit(); - }); - - return true; + /*optimize_for_size=*/optimize_for_size); } -bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); +void DotOpEmitter::EmitTiledLlvmIrGemv() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); - if (!primitive_util::IsFloatingPointType(primitive_type) && - !primitive_util::IsIntegralType(primitive_type)) { - return false; - } + CHECK(primitive_util::IsFloatingPointType(primitive_type) || + primitive_util::IsIntegralType(primitive_type)); MatMultDims mat_mult_dims = GetMatMultDims(); bool is_column_major_matrix_vector = false; @@ -1143,9 +308,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } } - if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return EmitSmallGemmIfProfitable(mat_mult_dims); - } + CHECK(is_column_major_matrix_vector || is_row_major_matrix_vector); int64 tiling_factor = GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); @@ -1177,44 +340,27 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter::Config config( + EmitColumnMajorGemv( /*scalar_type=*/primitive_type, /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - ColumnMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter::Config config( + EmitRowMajorGemv( /*scalar_type=*/primitive_type, - /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*tile_rows=*/tiling_factor, + /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - RowMajorMatrixVectorProductEmitter emitter(config, lhs_op, rhs_op, - addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } - - return true; } Status DotOpEmitter::Emit() { @@ -1240,11 +386,6 @@ Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. - // This routine assumes that the dot operation is not in a parallelized - // enclosing computation. - CHECK( - dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); - const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); @@ -1255,27 +396,41 @@ Status DotOpEmitter::Emit() { return EmitScalarDot(); } - if (EmitLlvmIrDotIfProfitable()) { - return Status::OK(); + switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_, + target_machine_features_)) { + case DotImplementationStrategy::kNaiveLlvmIr: + EmitNaiveLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemv: + EmitTiledLlvmIrGemv(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemm: + EmitTiledLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kEigen: + return EmitCallToRuntime(); } +} +void DotOpEmitter::EmitNaiveLlvmIrGemm() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { - return EmitCallToRuntime(); - } + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. - int64 lhs_reduction_dimension = - dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); - int64 rhs_reduction_dimension = - dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. - TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)); + CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), + rhs_shape.dimensions(rhs_reduction_dimension)); bool lhs_reduction_along_minor_dimension = lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0); @@ -1285,7 +440,7 @@ Status DotOpEmitter::Emit() { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_); + llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_); llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( @@ -1390,8 +545,6 @@ Status DotOpEmitter::Emit() { // Set the IR builder insert point to the exit basic block of the outer most // loop. b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); } Status DotOpEmitter::EmitScalarDot() { @@ -1438,7 +591,7 @@ Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - bool multi_threaded = ShouldUseMultiThreadedEigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -1531,11 +684,11 @@ Status DotOpEmitter::EmitCallToRuntime() { } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { - CHECK_EQ(dot_.shape().dimensions_size(), 2); + CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2); const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; return { /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), @@ -1549,74 +702,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -// Return whether the given shape is rank 2. -static bool IsRank2(const Shape& shape) { return shape.rank() == 2; } - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. -static bool AreValidGemmShapes( - const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, - const TargetMachineFeatures& target_machine_features) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - PrimitiveType output_primitive_type = output_shape.element_type(); - if (!(output_primitive_type == F64 || output_primitive_type == F32 || - output_primitive_type == F16)) { - return false; - } - - if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) { - return false; - } - - auto is_aligned = [&](const Shape& shape) { - return GetMinimumAlignmentForArray(shape, target_machine_features) >= - TargetMachineFeatures::kEigenExpectedTensorAlignment; - }; - - if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || - !is_aligned(output_shape)) { - return false; - } - - return true; -} - -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, - const TargetMachineFeatures& target_machine_features) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::IsZeroElementArray(lhs_shape) || - ShapeUtil::IsZeroElementArray(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), - target_machine_features)) { - const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); - return true; - } - } - - return false; -} - // For vector-matrix dot products, it is always profitable to make the Rhs // column major. absl::optional ProfitableToMakeDotOperandColumnMajor( @@ -1655,16 +740,157 @@ absl::optional ProfitableToMakeDotOperandColumnMajor( return {}; } -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { +namespace { +// Return whether the given shape is rank 2. +bool IsRank2(const Shape& shape) { return shape.rank() == 2; } + +bool IsSimpleLayout(const Layout& layout) { + return layout.tiles().empty() && layout.format() == DENSE; +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { + CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout())) + << lhs_shape.DebugString(); + CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout())) + << rhs_shape.DebugString(); + CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout())) + << output_shape.DebugString(); + + switch (output_shape.element_type()) { + case F64: + case F32: + case F16: + return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape); + default: + return false; + } +} + +bool IsAlignedGemm(const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) || + ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) { + return false; + } + + return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape, + dot_info.result_shape, target_machine_features); +} + +bool CanEmitTiledLlvmIrGemm( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + CHECK(IsAlignedGemm(dot_info, target_machine_features)); + + if (ShouldUseMultiThreadedEigen(config)) { + return false; + } + + int m = dot_info.result_shape.dimensions(0); + int k = dot_info.lhs_shape.dimensions( + dot_info.dim_nums.lhs_contracting_dimensions(0)); + int n = dot_info.result_shape.dimensions(1); + + if (!options::ForceEnableExperimentalLlvmIrGemm(config)) { + // TODO(sanjoy): We should make these numbers micro-arch specific. + bool small_gemm = + k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32)); + if (!small_gemm) { + return false; + } + } + + bool lhs_non_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 0; + bool rhs_non_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 1; + + if (lhs_non_canonical || rhs_non_canonical) { + return false; + } + + if (dot_info.result_shape.element_type() == F16) { + // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL + // adding this comment NFC. + return false; + } + + return true; +} + +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + PrimitiveType element_type = dot_info.result_shape.element_type(); // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. - const Shape& shape = dot.shape(); - return shape.dimensions_size() == 2 && - (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && - (primitive_util::IsFloatingPointType(shape.element_type()) || - primitive_util::IsIntegralType(shape.element_type())); + if (dot_info.result_shape.dimensions_size() == 2 && + (dot_info.result_shape.dimensions(0) == 1 || + dot_info.result_shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(element_type) || + primitive_util::IsIntegralType(element_type))) { + return DotImplementationStrategy::kTiledLlvmIrGemv; + } + + if (IsAlignedGemm(dot_info, target_machine_features)) { + return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) + ? DotImplementationStrategy::kTiledLlvmIrGemm + : DotImplementationStrategy::kEigen; + } + + return DotImplementationStrategy::kNaiveLlvmIr; } +} // namespace +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + // TODO(sanjoy): This is not quite right, it should be `impl_strategy == + // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy == + // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping + // the CL adding this comment NFC. + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; +} + +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; +} + +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + // This routine assumes that the dot operation is not in a parallelized + // enclosing computation. + CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty()); + + PrimitiveType type = target_array.GetShape().element_type(); + TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); + DotOpEmitter dot_emitter(DotInfo(dot), dot.name(), target_array, lhs_array, + rhs_array, addend_array, + executable_run_options_value, b, hlo_module_config, + target_machine_features); + return dot_emitter.Emit(); +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 4c2041b556aa8bf8fe8fb8e0674c0f4f04f0acae..105bd3005c86d87443b2528eba7b0106ad70590e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -30,9 +30,16 @@ limitations under the License. namespace xla { namespace cpu { +// Returns true if the two operands and the output of `dot_instr` must have row +// major layout. +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features); -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, +// Returns true our lowering strategy for `dot_instr` can fold in transposes to +// the either of the inputs. +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column @@ -41,129 +48,24 @@ bool PotentiallyImplementedAsEigenDot( absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo); -// Returns true to indicate that we can generate a tiled LLVM IR implementation -// for |dot|. -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); - -// Helper class for emitting LLVM IR to perform the dot operation. -class DotOpEmitter { - public: - // Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and - // place the result in target_array. IR is emitted at current insert point of - // the builder. Upon completion of the method, the insert point is set to the - // end of all instructions emitted for this operation. - // - // If `addend_array` is not nullptr then it must be an array of the same - // dimensions as the result, and the result is computed as `addend_array` + - // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported - // for Matrix-vector products. - static Status EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - private: - DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - // Emits the IR to perform the dot operation. - Status Emit(); - - // Emits instructions to perform a scalar dot product (a multiply of the - // LHS and RHS) and store the results in the target. - Status EmitScalarDot(); - - // Emit an LLVM IR implementation of the dot operation if we can. Returns - // true if an LLVM IR implementation was emitted. - bool EmitLlvmIrDotIfProfitable(); - - // Emits a call to the CPU runtime to perform the matrix multiply. - Status EmitCallToRuntime(); - - // Represents the dimensions of a matrix-matrix multiply operation. - struct MatMultDims { - // The number of rows in the LHS. - int64 m; - - // The number of columns in the LHS, which is also must be equal to the - // number of rows in the RHS. - int64 k; - - // The number of columns on the RHS. - int64 n; - - // True if the LHS matrix is column major. - bool lhs_column_major; - - // True if the LHS contraction dimension is not 1. - bool lhs_non_canonical; - - // True if the RHS matrix is column major. - bool rhs_column_major; - - // True if the RHS contraction dimension is not 0. - bool rhs_non_canonical; - - // True if the result matrix is column major. - bool target_column_major; - }; - - // Get the MatMultDims instance for the dot product this DotOpEmitter - // represents. Precondition: the dot is of rank 2 (and thus its operands are - // of rank 2 as well). - MatMultDims GetMatMultDims() const; - - bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims); - - // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector - // registers. - int64 GetGemvTilingFactor() const { - const int64 kDefaultTilingFactor = 8; - return options::LlvmIrGemvTilingFactor(hlo_module_config_) - .value_or(kDefaultTilingFactor); - } - - std::tuple GetGemmTileSize() const { - // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz - // - // TODO(b/80093688): Tune for other architectures and centralize this - // information in one place. - const std::tuple kDefaultTileSize = - std::tuple(11, 9, 1); - return options::LlvmIrGemmTileSize(hlo_module_config_) - .value_or(kDefaultTileSize); - } - - // Returns true if we should use an experimental implementation of GEMM - // (general matrix matrix multiplication) if possible. - bool EnableExperimentalLlvmIrGemm() const { - return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); - } - - // Returns true if we should call into multi-threaded Eigen routines. - bool ShouldUseMultiThreadedEigen() { - return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); - } - - const HloInstruction& dot_; - const llvm_ir::IrArray& target_array_; - const llvm_ir::IrArray& lhs_array_; - const llvm_ir::IrArray& rhs_array_; - const llvm_ir::IrArray* addend_array_; - llvm::Value* executable_run_options_value_; - llvm::IRBuilder<>* b_; - const HloModuleConfig& hlo_module_config_; - const TargetMachineFeatures& target_machine_features_; -}; - +// Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and +// place the result in target_array. IR is emitted at current insert point of +// the builder. Upon completion of the method, the insert point is set to the +// end of all instructions emitted for this operation. +// +// If `addend_array` is not nullptr then it must be an array of the same +// dimensions as the result, and the result is computed as `addend_array` + +// dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported +// for Matrix-vector products. +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..cc28918ed60a8086135846e2b9b1b9d75ec31ef6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +// ----------------------------------------------------------------------------- +// INTERNAL HEADER. +// +// This file exposes internal implementation details from dot_op_emitter.cc for +// unit tests. Please do not depend on this! +// +// ----------------------------------------------------------------------------- + +namespace xla { +namespace cpu { +namespace internal { + +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); + } +}; + +// Dictates how a dot operation is implemented. +enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trival dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implemetns a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts. + kEigen, +}; + +// Returns the implementation strategy for a dot with the configuration +// `dot_info`. +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); +} // namespace internal +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 169d628923b05c43590782bcbebb0e5e539d83e3..91e369335455669c1ea16cedc930eb2c34b76abe 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,11 +24,9 @@ limitations under the License. #include #include +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/platform/logging.h" -// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -70,6 +68,8 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/logging.h" namespace xla { @@ -970,10 +970,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. - return DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_); + return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_); } StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( @@ -1399,7 +1399,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { int64 delta = 0; for (int64 i = 0; i < operand_shape.dimensions_size(); i++) { - if (reduced_dims.count(i)) { + if (reduced_dims.contains(i)) { delta++; } else { InsertOrDie(&unreduced_dim_map, i, i - delta); @@ -1412,7 +1412,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { for (int64 operand_dim_idx = 0; operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); - if (!reduced_dims.count(operand_dim)) { + if (!reduced_dims.contains(operand_dim)) { if (FindOrDie(unreduced_dim_map, operand_dim) != result_shape.layout().minor_to_major(result_dim_idx++)) { return false; @@ -1709,10 +1709,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( vectorization_factor_in_bytes / ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); - bool is_reduction_over_minor_dimension = - std::find(dimensions.begin(), dimensions.end(), - LayoutUtil::Minor(arg->shape().layout(), 0)) != - dimensions.end(); + bool is_reduction_over_minor_dimension = absl::c_linear_search( + dimensions, LayoutUtil::Minor(arg->shape().layout(), 0)); unsigned element_alignment = tensorflow::MathUtil::GCD( ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), @@ -1990,7 +1988,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // The memcpy will copy elements that are logically this shape (allowed to be // scalar). const Shape logical_element_shape = ShapeUtil::FilterDimensions( - [&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); }, + [&inner_dims](int64 dim) { return inner_dims.contains(dim); }, operand->shape()); const int64 primitive_elements_per_logical_element = @@ -2205,10 +2203,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, &addend_array, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_)); + TF_RETURN_IF_ERROR( + EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + &addend_array, GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2401,8 +2399,7 @@ StatusOr IrEmitter::EmitFastConcatenate( int64 concat_dim = concatenate->dimensions(0); const Layout& output_layout = output_shape.layout(); auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); - auto concat_dim_layout_itr = - std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); + auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim); std::vector inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector outer_dims(std::next(concat_dim_layout_itr), @@ -2956,8 +2953,7 @@ Status IrEmitter::ElementTypesSameAndSupported( TF_RET_CHECK(!operands.empty()); PrimitiveType primitive_type = operands[0]->shape().element_type(); - if (std::find(supported_types.begin(), supported_types.end(), - primitive_type) == supported_types.end()) { + if (!absl::c_linear_search(supported_types, primitive_type)) { return Unimplemented("unsupported operand type %s in op %s", PrimitiveType_Name(primitive_type), HloOpcodeString(instruction.opcode())); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index db76de4bb2b8ed568bf2557a30fa216d0cbe518d..a6fb11dcbf9bb201ba8837866e2f509c48bfd061 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -448,7 +448,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, computation_to_profile_idx_; // Maps HLOs to Values emitted for them. - std::unordered_map emitted_value_; + absl::flat_hash_map emitted_value_; llvm_ir::AliasAnalysis alias_analysis_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 3b423f639137149f7efdee91a02bbcf01d1cc6e1..6121d1ca9a5c785cedd947200d3e7e320aa06bc2 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -146,8 +146,6 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || - PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || instruction->shape().IsTuple()) { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 1ed743afc30af7c7ff38c7d2a738f2e376270952..1f7204e67a413efabd34cd7d88ced4c82ee7a5df 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -20,6 +20,10 @@ limitations under the License. #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb6c44b70ab34d0a294880b5de4fe0b3ba5e19e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -0,0 +1,1014 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" + +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace cpu { +namespace { + +using tensorflow::int64; + +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { + public: + // Constructs a MemoryTile that can operate on tiles consisting of + // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at + // `major_dim_offset` in the major dimension. The tile size along the minor + // dimension is the vector size, and that is implicitly determined by `vsl`. + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, + llvm::Value* matrix, int64 matrix_size_along_minor_dim, + llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) + : vsl_(vsl), b_(b) { + pointers_.reserve(tile_size_along_major_dim); + for (int64 i = 0; i < tile_size_along_major_dim; i++) { + llvm::Value* total_offset = + b->CreateMul(b->getInt64(matrix_size_along_minor_dim), + b->CreateAdd(b->getInt64(i), major_dim_offset)); + pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); + } + } + + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector LoadTile(llvm::Value* minor_dim_offset) const { + std::vector result; + result.reserve(pointers_.size()); + for (const auto& pointer : pointers_) { + result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); + } + return result; + } + + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(absl::Span tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); + } + } + return result; + } + + private: + VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* b_; + std::vector pointers_; +}; + +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + +// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +--+--+--+--+ +// |M00|M10|M20|M30| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M03|M13|M23|M33| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// +// (Legend: rows are horizontal and columns are vertical; and each column is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is from the column major left matrix. +// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] +// vector loaded from the RHS vector. +// +// As we iterate through the column dimension, we compute the change to the +// result vector by an elementwise multiplication between the two tiles above +// followed by a reduction along the major dimension: +// +// +-----------------------------------+ +// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | +// +-----------------------------------+ +// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | +// Result[R:R+4] += +-----------------------------------+ +// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | +// +-----------------------------------+ +// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | +// +-----------------------------------+ +// +// Where R is the starting row for the tile. +// +// We have an inner epilogue loop to deal with the "C" submatrix and an outer +// epilogue loop to deal with the B,D submarix. +// +// TODO(sanjoy): We should investigate if using gather loads and scatter stores +// can be used here have the same inner loop for both column-major and row-major +// matrix-vector products. +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { + public: + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, + llvm::IRBuilder<>* b) + : config_(config), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + b_(b), + ksl_(b_), + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); + } + + void Emit(); + + const Config& config() const { return config_; } + + private: + void EmitOuterLoopBody(llvm::Value* column, int64 column_count, + bool is_first_column); + + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), + /*major_dim_offset=*/column_start, + /*tile_size_along_major_dim=*/column_count); + } + + // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous + // sequence of `count` values, each one broadcasted to the vector width. + std::vector LoadRhsTile(llvm::Value* offset, int64 count) { + llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); + std::vector result; + result.reserve(count); + for (int64 i = 0; i < count; i++) { + result.push_back(vsl_.LoadBroadcast(base_pointer, i)); + } + return result; + } + + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, + const std::vector& rhs_tile, + int64 columns, bool is_first_column); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, + bool is_first_tiled_column); + + Config config_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( + llvm::Value* column, int64 column_count, bool is_first_column) { + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, + /*column_count=*/column_count); + + std::vector rhs_tile = + LoadRhsTile(column, /*count=*/column_count); + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, + /*columns=*/column_count, is_first_column); + EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); +} + +void ColumnMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); + + if (column_remainder != 0) { + EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, + column_limit == 0); + } +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, + int64 columns, bool is_first_column) { + int64 row_limit = m() - (m() % tile_rows()); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { + return; + } + + llvm::Value* columns_llvm = b_->getInt64(columns); + + // for (col = current_tile_col; col < (columns + current_tile_col); col++) + // for (row = row_start, row < m_; row++) { + // result[row] += lhs[row, col] * rhs[col] + // // Also take into account that if col is 0 then result[row] is not + // // initialized. + // } + + ksl_.For( + "dot.inner.epilg.outer", /*start=*/current_tile_col, + /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), + /*step=*/1, /*peel_first_iteration=*/false, + [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { + llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); + llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), + /*step=*/1, [&](llvm::Value* scalar_row) { + llvm::Value* product = vsl_.Mul( + vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); + llvm::Value* setting_result_first_time = b_->CreateAnd( + is_first_scalar_col, b_->getInt1(is_first_tiled_column)); + ksl_.If( + setting_result_first_time, + /*true_block_generator=*/ + [&]() { + if (addend_) { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), + product), + result_, scalar_row); + } else { + vsl_.StoreScalar(product, result_, scalar_row); + } + }, + /*false_block_generator=*/ + [&]() { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), + result_, scalar_row); + }); + }); + }); +} + +// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +// |M00|M10|M20|M30| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| +// +---+---+---+---+ +// |M03|M13|M23|M33| +// +---+---+---+---+ +// +// (Legend: rows are horizontal and columns are vertical; and each row is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is loaded from the row major left matrix. +// b. The right vector is loaded from the RHS vector. +// +// We keep 4 vector accumulators accumulating the following four vector +// expressions as we iterate over the row dimension: +// +// +------+------+------+------+ +// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) +// +------+------+------+------+ +// +// In the end we do a horizontal reduction over these 4 vector accumulators to +// get 4 values in the result vector. +// +// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer +// epilogue loop to deal with the C,D submatrix. +class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { + public: + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b) + : config_(config), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + b_(b), + ksl_(b_), + vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); + } + + void Emit(); + + const Config& config() const { return config_; } + + private: + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), + /*major_dim_offset=*/row_start, + /*tile_size_along_major_dim=*/row_count); + } + + void EmitOuterLoopBody(llvm::Value* row, int64 row_count); + + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, + std::vector* vector_accumulators); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators); + + Config config_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, + int64 row_count) { + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, + /*row_count=*/row_count); + std::vector vector_accumulators; + std::vector scalar_accumulators; + for (int i = 0; i < row_count; i++) { + vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); + scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); + } + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, + &vector_accumulators); + EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, + &scalar_accumulators); + + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + + std::vector horizontal_sums; + if (row_count == vsl_.vector_size()) { + if (addend_) { + horizontal_sums = vsl_.ComputeHorizontalSums( + std::move(accumulator_values), vsl_.LoadVector(addend_, row)); + } else { + horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + } else { + horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + + for (int i = 0; i < row_count; i++) { + llvm::Value* result_value = + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); + llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); + if (addend_ && row_count != vsl_.vector_size()) { + result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); + } + vsl_.StoreScalar(result_value, result_, offset); + } +} + +void RowMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + + if (row_remainder != 0) { + EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); + } +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + MemoryTile* lhs_memory_tile, int64 rows, + std::vector* vector_accumulators) { + int64 column_limit = k() - (k() % tile_cols()); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { + return; + } + + for (int r = 0; r < rows; r++) { + llvm::Value* total_offset = b_->CreateMul( + b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); + } +} + +// This class implements a tiled matrix multiplication algorithm, intended for +// multiplying small matrices that don't need cache tiling. +// +// In the future this can be used as the innermost GEBP loop in a GEMM kernel as +// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of +// high-performance matrix multiplication." ACM Transactions on Mathematical +// Software (TOMS) 34.3 (2008): 12.". +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class TiledSmallGemmEmitter { + public: + // Describe the dimensions of the kernel. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Represents the configuration of the emitter. The LLVM IR emitted by the + // emitter, modulo the LLVM values holding the input and output buffers, must + // be a function of the instance of `Config` passed to it. + // + // `dims` holds the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // ] in the RHS. + class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + max_vector_count_(max_vector_count), + min_vectorization_width_(min_vectorization_width), + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} + + string GetCacheKey() const { + return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", + dims().ToString(), "_", max_vectorization_width(), + "_", min_vectorization_width(), "_", tile_size_m(), + "_", tile_size_k()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 max_vector_count() const { return max_vector_count_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 max_vector_count_; + int64 min_vectorization_width_; + int64 tile_size_m_; + int64 tile_size_k_; + }; + + // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b) + : lhs_(lhs), + rhs_(rhs), + result_(result), + config_(config), + b_(b), + ksl_(b_) { + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); + CHECK_GT(tile_size_k(), 0); + } + + void Emit(); + + private: + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. + void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; +}; + +void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } + +void TiledSmallGemmEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. + + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, + "gemm"); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); + HandleResiduesOnK(&vsl, n_i, n_i_next); + }); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. +// +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. For example, +// L[0,_,_] is computed as: +// +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ +// +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void TiledSmallGemmEmitter::EmitTiledGemm( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.For( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i)); + ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, + tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); +} + +} // namespace + +void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, + /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, addend, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result) { + RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, + result, b); + emitter.Emit(); + }); +} + +void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, + /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, addend, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result) { + ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, + result, b); + emitter.Emit(); + }); +} + +void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + TiledSmallGemmEmitter::Config config( + /*scalar_type=*/scalar_type, + TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_vectorization_width, + /*max_vector_count=*/max_vector_count, + /*min_vectorization_width=*/min_vectorization_width, + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) { + TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/result, b); + small_gemm_emitter.Emit(); + }); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..0a82326cc3704bce8c122261383249c60eda1f3a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +// These routines emit LLVM IR implementing tiled GEMM and GEMV routines. + +void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size); + +void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, + tensorflow::int64 k, tensorflow::int64 n, + tensorflow::int64 max_vectorization_width, + tensorflow::int64 max_vector_count, + tensorflow::int64 min_vectorization_width, + tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k, + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index b2ba2617902104bfea06713332fa1c2aedea536d..855424067d26d4968270e5f24b11f5a053b70a55 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -156,29 +158,187 @@ Status DecomposeBatchDot(HloInstruction* dot) { return computation->ReplaceInstruction(dot, new_dot); } +// Convert a dot into a canonical form where non-contracting and contracting +// dimensions are reshaped together and batch dimensions are the most major +// dimensions. The requires transposing and reshapes the lhs and rhs and +// reshaping the output batch to the original shape. +Status CanonicalizeDot(HloInstruction* original_dot) { + auto computation = original_dot->parent(); + const auto& original_dnums = original_dot->dot_dimension_numbers(); + const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size(); + const int64 num_contracting_dims = + original_dnums.lhs_contracting_dimensions_size(); + + const auto& lhs_shape = original_dot->operand(0)->shape(); + const int64 lhs_rank = lhs_shape.rank(); + const int64 num_lhs_non_contracting_dims = + lhs_rank - num_batch_dims - num_contracting_dims; + + std::vector lhs_non_contracting_dims; + lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims); + int64 lhs_contracting_size = 1; + int64 lhs_non_contracting_size = 1; + std::vector batch_dim_sizes; + batch_dim_sizes.reserve(num_batch_dims); + for (int64 i = 0; i < lhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) { + lhs_contracting_size *= lhs_shape.dimensions(i); + } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(), + i)) { + batch_dim_sizes.push_back(lhs_shape.dimensions(i)); + } else { + lhs_non_contracting_dims.push_back(i); + lhs_non_contracting_size *= lhs_shape.dimensions(i); + } + } + // The canonical form of the lhs is + // [BatchDims, NonContractingDims, ContractingsDims] + std::vector lhs_transpose; + lhs_transpose.reserve(lhs_rank); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_batch_dimensions().begin(), + original_dnums.lhs_batch_dimensions().end()); + lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(), + lhs_non_contracting_dims.end()); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_contracting_dimensions().begin(), + original_dnums.lhs_contracting_dimensions().end()); + HloInstruction* transposed_lhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose), + lhs_shape), + original_dot->mutable_operand(0), lhs_transpose)); + std::vector lhs_reshape_dims = batch_dim_sizes; + lhs_reshape_dims.push_back(lhs_non_contracting_size); + lhs_reshape_dims.push_back(lhs_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_lhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims), + transposed_lhs)); + + const auto& rhs_shape = original_dot->operand(1)->shape(); + const int64 rhs_rank = rhs_shape.rank(); + const int64 num_rhs_non_contracting_dims = + rhs_rank - num_batch_dims - num_contracting_dims; + std::vector rhs_non_contracting_dims; + rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims); + int64 rhs_non_contracting_size = 1; + int64 rhs_contracting_size = 1; + for (int64 i = 0; i < rhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) { + rhs_contracting_size *= rhs_shape.dimensions(i); + } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(), + i)) { + rhs_non_contracting_dims.push_back(i); + rhs_non_contracting_size *= rhs_shape.dimensions(i); + } + } + + // The canonical form of the rhs is + // [BatchDims, ContractingsDims, NonContractingDims] + std::vector rhs_transpose; + rhs_transpose.reserve(rhs_rank); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_batch_dimensions().begin(), + original_dnums.rhs_batch_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_contracting_dimensions().begin(), + original_dnums.rhs_contracting_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(), + rhs_non_contracting_dims.end()); + HloInstruction* transposed_rhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose), + rhs_shape), + original_dot->mutable_operand(1), rhs_transpose)); + + std::vector rhs_reshape_dims = batch_dim_sizes; + rhs_reshape_dims.push_back(rhs_contracting_size); + rhs_reshape_dims.push_back(rhs_non_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_rhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims), + transposed_rhs)); + + std::vector dot_dims = batch_dim_sizes; + dot_dims.push_back(lhs_non_contracting_size); + dot_dims.push_back(rhs_non_contracting_size); + + DotDimensionNumbers dot_dnums; + for (int64 i = 0; i < num_batch_dims; ++i) { + dot_dnums.add_lhs_batch_dimensions(i); + dot_dnums.add_rhs_batch_dimensions(i); + } + dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1); + dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); + + HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims), + reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config())); + + return computation->ReplaceInstruction( + original_dot, computation->AddInstruction(HloInstruction::CreateReshape( + original_dot->shape(), dot))); +} + } // namespace StatusOr DotDecomposer::Run(HloModule* module) { XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); - // Gather all batch Dot operations. - std::vector batch_dots; + // Gather all Non-canonical Dot operations. + std::vector non_canonical_dots; for (auto* computation : module->MakeNonfusionComputations()) { for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kDot) { continue; } const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { - batch_dots.push_back(instruction); + // A dot it not canonical if there are more than one contracting + // dimension. + if (dnums.lhs_contracting_dimensions_size() > 1) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty()) { + continue; + } + std::vector canonical_batch_dims( + dnums.lhs_batch_dimensions_size()); + absl::c_iota(canonical_batch_dims, 0); + if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) || + !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) { + non_canonical_dots.push_back(instruction); } } } - // Decompose each batch Dot in 'batch_dots'. bool changed = false; - for (auto* dot : batch_dots) { - TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + for (auto* dot : non_canonical_dots) { + TF_RETURN_IF_ERROR(CanonicalizeDot(dot)); changed = true; } + + if (decompose_batch_dot_) { + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (!dnums.lhs_batch_dimensions().empty()) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + } XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index f2b7e9a186135d6246d2dc6c5fb3fa0b70145d10..2b158d7a6ec510ce4cbc56bddc5cca71ac4f14f4 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -433,7 +433,7 @@ Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( /* static */ StatusOr DynamicDimensionInference::Run( HloModule* module) { - VLOG(0) << "Param Config " << module->dynamic_parameter_binding().ToString(); + VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString(); DynamicDimensionInference inference(module); TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); return inference; diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..ede4a5afdb7d02b31b9d0839fc8716f9b814e544 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +StatusOr DynamicIndexSplitter::Run(HloModule* module) { + bool changed = false; + + std::vector computations = + module->MakeNonfusionComputations(); + for (HloComputation* computation : computations) { + for (HloInstruction* dynamic_op : computation->MakeInstructionPostOrder()) { + switch (dynamic_op->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + break; + default: + continue; + } + auto parent = dynamic_op->parent(); + bool is_update = dynamic_op->opcode() == HloOpcode::kDynamicUpdateSlice; + int64 index_operand_number = Cast(dynamic_op) + ->first_index_operand_number(); + auto index_operand = dynamic_op->mutable_operand(index_operand_number); + if (ShapeUtil::IsScalar(index_operand->shape())) { + // This DS/DUS already uses scalar indices. + continue; + } + TF_RET_CHECK(index_operand->shape().rank() == 1); + int64 num_indices = index_operand->shape().dimensions(0); + if (num_indices == 0) { + // If the operand dimension is 0, directly replace R0 DS/DUS with the + // operand (for DS) or update (for DUS). + if (is_update) { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(1))); + } else { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(0))); + } + changed = true; + continue; + } + auto index_element_type = index_operand->shape().element_type(); + std::vector index_array; + for (int64 dim = 0; dim < num_indices; ++dim) { + auto slice = parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(index_element_type, {1}), index_operand, {dim}, + {dim + 1}, {1})); + auto bitcast = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(index_element_type, {}), slice)); + index_array.push_back(bitcast); + } + auto new_dynamic_op = + is_update + ? HloInstruction::CreateDynamicUpdateSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + dynamic_op->mutable_operand(1), absl::MakeSpan(index_array)) + : HloInstruction::CreateDynamicSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + absl::MakeSpan(index_array), + dynamic_op->dynamic_slice_sizes()); + TF_CHECK_OK(parent->ReplaceWithNewInstruction(dynamic_op, + std::move(new_dynamic_op))); + changed = true; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.h b/tensorflow/compiler/xla/service/dynamic_index_splitter.h new file mode 100644 index 0000000000000000000000000000000000000000..3c12e3a4af287ad2272a08ba54cd99c2cad9d451 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Convert R1 index operands to DynamicSlice and DynamicUpdateSlice ops into +// separate scalars. +class DynamicIndexSplitter : public HloModulePass { + public: + DynamicIndexSplitter() = default; + absl::string_view name() const override { return "dynamic-index-splitter"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..98029d1faff7d669730f6b66e38fcefece70f0eb --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +class DynamicIndexSplitterTest : public HloTestBase {}; + +TEST_F(DynamicIndexSplitterTest, DynamicSlice) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 1) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) { + const char* const kDynamicUpdateSlice = R"( + HloModule DynamicUpdatedSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + update = s32[1,1,1] parameter(2) + ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kDynamicUpdateSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 2) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, AlreadyScalar) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + index.0 = s32[] parameter(1) + index.1 = s32[] parameter(2) + index.2 = s32[] parameter(3) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc new file mode 100644 index 0000000000000000000000000000000000000000..4db280f817141bd52e3a5b9564600a618f81aeac --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// ChooseIdentityValue looks at the instruction and returns a identity value +// which, when padded, doesn't change the result of the instruction. +// +// nullopt is returned if padding doesn't need to be reset. +StatusOr ChooseIdentityValue(HloInstruction* inst) { + HloComputation* comp = inst->parent(); + // Padding on elementwise operation doesn't affect the result of the effective + // data. + if (inst->IsElementwise()) { + return nullptr; + } + + switch (inst->opcode()) { + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: { + // Because of the way we do reduce, we already require the `init` operand + // of hlo reduce instruction to be identity value. Here we reuse the + // operand. + return inst->mutable_operand(1); + } + + case HloOpcode::kConvolution: + case HloOpcode::kDot: { + // Use 0 as padding value for convolution and dot. + PrimitiveType ptype = inst->shape().element_type(); + return comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(ptype))); + } + + case HloOpcode::kPad: { + return inst->mutable_operand(1); + } + case HloOpcode::kParameter: + case HloOpcode::kGetDimensionSize: + case HloOpcode::kReshape: + case HloOpcode::kTuple: + case HloOpcode::kAllReduce: + case HloOpcode::kBroadcast: + return nullptr; + default: + return UnimplementedStrCat("Unimplimented padding for instruction: ", + inst->ToString()); + } +} + +} // namespace + +StatusOr DynamicPadder::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Pre DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* inst : computation->instructions()) { + for (int64 operand_num = 0; operand_num < inst->operand_count(); + ++operand_num) { + HloInstruction* operand = inst->mutable_operand(operand_num); + if (!operand->shape().IsArray()) { + continue; + } + for (int64 dim = 0; dim < operand->shape().rank(); ++dim) { + HloInstruction* dynamic_size = + dynamic_dimension_inference.GetDynamicSize(operand, {}, dim); + if (dynamic_size == nullptr) { + continue; + } + VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @" + << dim; + TF_ASSIGN_OR_RETURN(HloInstruction * identity_value, + ChooseIdentityValue(inst)); + if (identity_value == nullptr) { + continue; + } + + // For each dimension, first generates a mask representing the + // effective area of data and padded area of data using iota and + // dynamic_size. For example, given a dimension of 7 elements and 5 + // effective elements: + // + // iota = [0, 1, 2, 3, 4, 5, 6] + // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5] + // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f] + // + // Once the mask is generated, the input data is then padded using the + // mask and pad value. + // + const Shape mask_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::U32); + const Shape pred_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::PRED); + HloInstruction* iota = computation->AddInstruction( + HloInstruction::CreateIota(mask_shape, dim)); + + HloInstruction* broadcasted_effective_size = + computation->AddInstruction(HloInstruction::CreateBroadcast( + mask_shape, dynamic_size, {})); + HloInstruction* pred = computation->AddInstruction( + HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota, + broadcasted_effective_size)); + + HloInstruction* broadcasted_identity_value = + computation->AddInstruction(HloInstruction::CreateBroadcast( + operand->shape(), identity_value, {})); + HloInstruction* padded = + computation->AddInstruction(HloInstruction::CreateTernary( + operand->shape(), HloOpcode::kSelect, pred, operand, + broadcasted_identity_value)); + TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); + operand = inst->mutable_operand(operand_num); + changed = true; + } + } + } + } + HloDCE dce; + TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); + VLOG(2) << "Post DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.h b/tensorflow/compiler/xla/service/dynamic_padder.h new file mode 100644 index 0000000000000000000000000000000000000000..509269f7f56746fa5516ad917a04221587c6dcca --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// With bounded shapes, only part of the shape contains effective data and the +// rest contains padded data, whose value can be anything depending on the +// source of the data. When a bounded shape is directly consumed by an +// instruction that collapses dimensions (reduce for example), the padding data +// would affect result of the instruction. +// +// DynamicPadder uses DynamicDimensionInference to detect bounded shapes in a +// hlo module, it then inserts certain instructions to reset the padding into an +// identity value so that in doesn't affect the result of subsequent +// instruction. For example, it'd reset the padding to 0 before a bounded shape +// is consumed by a reduce-sum. +class DynamicPadder : public HloModulePass { + public: + absl::string_view name() const override { return "dynamic_padder"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..55a11286e4596d87c330315322cae704fc5cd707 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicPadderTest : public HloTestBase { + protected: + DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } + + StatusOr RunPadder() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before padder"); + + DynamicPadder padder; + + return padder.Run(module_.get()); + } + + void ExpectPadded(const HloInstruction* inst) { + EXPECT_THAT(inst, + op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())), + ::testing::_, op::Broadcast())); + } + + HloComputation* GetScalarAddComputation() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {}); +}; + +TEST_F(DynamicPadderTest, ReduceTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetScalarAddComputation())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(reduce->operand(0)); +} + +TEST_F(DynamicPadderTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(conv->operand(0)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 6f928fcbaab53d804d580ba31f1b7fad0ed0bdc5..f84c115e0abec378f0401a15e6bd381983d0ff34 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1758,9 +1758,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); + // TODO(b/118437727): Remove the R1 path. + llvm::Value* start_index_value; + if (hlo->operand(1)->shape().rank() == 1) { + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); + TF_ASSIGN_OR_RETURN(start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + } else { + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + start_index_value, + operand_to_generator.at(hlo->operand(1 + i))(zero_index)); + } // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) @@ -1905,9 +1914,19 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); + + llvm::Value* start_index_value; + // TODO(b/118437727): Remove the R1 path. + if (hlo->operand(2)->shape().rank() == 1) { + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); + TF_ASSIGN_OR_RETURN(start_index_value, + operand_to_generator.at(hlo->operand(2))(dim_index)); + } else { + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + start_index_value, + operand_to_generator.at(hlo->operand(2 + i))(zero_index)); + } // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6c23f921f40cac0dc5df08494dc1b63e6d1d5e93..2be5d918eeabf4db812425ccc2a9c4ba067ddaac 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3,6 +3,11 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) # Apache 2.0 @@ -24,12 +29,6 @@ filegroup( ]), ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - xla_proto_library( name = "backend_configs", srcs = ["backend_configs.proto"], @@ -94,8 +93,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -135,6 +134,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -263,7 +264,9 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -362,6 +365,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -695,6 +699,8 @@ cc_library( "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", + "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -725,6 +731,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/cuda:cuda_diagnostics", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 528209abc75777440163c2e1512658b8ad36315b..eb59ee5a1d47b6b706ef3f53a76069b3538eb6b7 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -57,16 +58,16 @@ StatusOr> BufferAllocations::Builder::Build( // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. - if (registered_buffers_.count(i)) { - se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % expected_alignment != + if (se::DeviceMemoryBase* address = + tensorflow::gtl::FindOrNull(registered_buffers_, i)) { + if (reinterpret_cast(address->opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %d must be a multiple of %x, but " "was %p", - i, kEntryParameterAlignBytes, address.opaque()); + i, kEntryParameterAlignBytes, address->opaque()); } - buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); + buffer_allocations->SetBuffer(i, *address); continue; } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 14186b8faa68ad8492ea4863fcd7bd746e2eae48..9413ac2cff7c8d3ec4be6662569c580060bf1173 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -52,7 +53,8 @@ class BufferAllocations { DeviceMemoryAllocator* memory_allocator); private: - std::map registered_buffers_; + absl::flat_hash_map + registered_buffers_; }; ~BufferAllocations(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 6d6780fa1c7b0c636eb771c40e74f074cd8c4c4b..309b0aca64954e64509d731dce28ce9d8da4ee43 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -146,7 +146,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. StatusOr -CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { +CudnnConvAlgorithmPicker::PickBestAlgorithm( + const HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. const bool cross_check_enabled = @@ -249,12 +250,13 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - backend_config.set_algorithm(alg.algo_id()); - backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); - TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); + // Use assignment instead of brace-list to make GCC 4.9 happy. + RunConvOptions options; + options.profile_result = &profile_result; + options.algo_override = alg; bool launch_ok = RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, &stream, &profile_result) + &scratch_allocator, &stream, options) .ok(); if (launch_ok && profile_result.is_valid()) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 642af787afc71586d722ecc7e529ed8b3fa64d33..4991db0948589e479a202f4082d96df275f6e088 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -56,7 +56,8 @@ class CudnnConvAlgorithmPicker : public HloModulePass { StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); + StatusOr PickBestAlgorithm( + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 3425e1b4942aaf1011ba1bf1c50dd7e79c1f9807..b628f27f4b2ba8ccf17fd531d8a0c25cb99d9396 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -395,32 +395,36 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { ScratchBufAllocator scratch_allocator(scratch_buf); return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator, - stream, profile_result); + stream, options); } Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { TF_ASSIGN_OR_RETURN(CudnnConvParams params, GetCudnnConvParams(conv, operand_buffers, result_buffer)); + if (options.algo_override) { + params.algorithm = AlgorithmConfig(*options.algo_override); + } + PrimitiveType output_primitive_type = conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F32: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F64: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); default: LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h index edbc75a94a1238540390b93f0fa5217852c7781f..25b2461ca61251c6cb7b89b1f91da0f1636a3647 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -28,6 +28,14 @@ limitations under the License. namespace xla { namespace gpu { +struct RunConvOptions { + // Nullable output-parameter pointer for profiling results. + se::dnn::ProfileResult* profile_result = nullptr; + + // Use this algorithm, instead of the one from the instruction. + absl::optional algo_override; +}; + // This file contains low-level routines for running cudnn convolutions. // Calls into cudnn to run the specified convolution. @@ -46,13 +54,13 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 9634c92786ad4dd0d62bae9cbb087b4ef6b5866f..69aaaceca112364a4fd562f6a5eff1629fd3fc54 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -45,10 +46,10 @@ void HloToIrBindings::EmitBasePointersForHlos( // An HLO can have duplicated operands. This data structure remembers which // operand HLOs are already bound to avoid rebinding the same HLO. - std::set already_bound_for_this_function; + absl::flat_hash_set already_bound_for_this_function; auto arg_iter = function->arg_begin(); for (const HloInstruction* io_hlo : io_hlos) { - if (!already_bound_for_this_function.count(io_hlo)) { + if (!already_bound_for_this_function.contains(io_hlo)) { if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); } else { @@ -63,7 +64,7 @@ void HloToIrBindings::EmitBasePointersForHlos( temp_buffer_base_->setName("temp_buffer"); for (const HloInstruction* non_io_hlo : non_io_hlos) { - if (already_bound_for_this_function.count(non_io_hlo)) { + if (already_bound_for_this_function.contains(non_io_hlo)) { continue; } already_bound_for_this_function.insert(non_io_hlo); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index c0edae530cedba45c897b07b7b9cc72eaaab397c..f57b594e9c18078a3bbbf4d2b4db7e989c4edfdd 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -61,7 +62,7 @@ class HloToIrBindings { // Returns whether `hlo` is bound to an LLVM IR value. bool BoundToIrValue(const HloInstruction& hlo) const { - return base_ptrs_.count(&hlo); + return base_ptrs_.contains(&hlo); } llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } @@ -110,7 +111,8 @@ class HloToIrBindings { // For an instruction that generates multiple outputs, the root will be a // tuple shape. The IrArray for each element output is stored in the subnode // in the ShapeTree. - std::unordered_map> base_ptrs_; + absl::flat_hash_map> + base_ptrs_; // The address of the memory block that contains all temporary buffers. llvm::Value* temp_buffer_base_ = nullptr; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 5d25a032a99b2e477ba81e9d8ea31b2e1ddf93ec..caccb1889977de4601889d4cc39ac74418ac6652 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -154,20 +154,17 @@ bool IsReductionToVector(const HloInstruction& reduce) { const HloInstruction* input = reduce.operand(0); std::vector dims_to_keep; for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) { - if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(), - dim)) { + if (!absl::c_linear_search(reduce.dimensions(), dim)) { dims_to_keep.push_back(dim); } } return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), dims_to_keep) && - ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions( - [&dims_to_keep](int64 dim) { - return std::count( - dims_to_keep.begin(), - dims_to_keep.end(), dim); - }, - input->shape())); + ShapeUtil::Equal( + reduce.shape(), + ShapeUtil::FilterDimensions( + [&](int64 dim) { return absl::c_count(dims_to_keep, dim); }, + input->shape())); } // This emits a device-side call to diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 063d493d90a2330c535ed4a16d063d6db9b49cd1..34ddeb1d417729e467cebba4986a763dc685c05b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -88,6 +88,9 @@ namespace xla { namespace gpu { using llvm_ir::KernelMappingScheme; +using EmitElementFunction = + std::function; namespace { @@ -1506,10 +1509,10 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( return !allocation->is_constant(); }); - std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), - [](const BufferAllocation* a, const BufferAllocation* b) { - return a->index() < b->index(); - }); + absl::c_sort(non_constant_buffers, + [](const BufferAllocation* a, const BufferAllocation* b) { + return a->index() < b->index(); + }); llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); @@ -2133,53 +2136,86 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -void EmitFullElementalTile( - const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Type* index_ty, - const std::function& emit_elem_function) { +std::tuple GetStartOffsetAndStepForX( + int64 tile_size_x, int64 num_threads_x, + const KernelMappingScheme* mapping_scheme, llvm::IRBuilder<>* builder, + llvm::Value* x, llvm::Type* index_ty) { + llvm::Value* start_offset_x; + int64 step_x; + if (mapping_scheme->DilatedX()) { + start_offset_x = x; + step_x = num_threads_x; + } else { + start_offset_x = builder->CreateMul( + x, llvm::ConstantInt::get(index_ty, tile_size_x / num_threads_x)); + step_x = 1; + } + return std::make_tuple(start_offset_x, step_x); +} + +void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); + + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0), /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { - IrArray::Index source_idx_y = tile_origin_index.AddOffsetToDim( + IrArray::Index source_idx_y = source_idx.AddOffsetToDim( y_indvar, KernelMappingScheme::DimY, builder); llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = source_idx_y.AddOffsetToDim( - llvm::ConstantInt::get(index_ty, j), + + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_y_x = source_idx_y.AddOffsetToDim( + llvm::ConstantInt::get(index_ty, j * step_x), KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - emit_elem_function(source_idx, y_loc, x_loc); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), + start_offset_x); + emit_elem_function(source_idx_y_x, y_loc, x_loc, j); } }); } -void EmitPartialElementalTile( - const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - llvm::Type* index_ty, - const std::function& emit_elem_function) { +void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, + KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, + llvm::Value* tile_width, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_x = + source_idx.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j * step_x), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), start_offset_x); ksl->If( loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), @@ -2199,14 +2235,13 @@ void EmitPartialElementalTile( /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->If( - loop_name + "_y_in_tile", - builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function( - source_idx.AddOffsetToDim( - y_indvar, KernelMappingScheme::DimY, builder), - y_loc, x_loc); - }); + ksl->If(loop_name + "_y_in_tile", + builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function( + source_idx_x.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc, j); + }); }); }); } @@ -2225,8 +2260,7 @@ void EmitTiledElementalCodeWithBoundsCheck( const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); @@ -2262,7 +2296,7 @@ void EmitTiledElementalCodeWithBoundsCheck( void IrEmitterUnnested::EmitTileElementForCopy( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); // TODO(jlebar): Add AA metadata to this load. @@ -2292,7 +2326,7 @@ void IrEmitterUnnested::EmitTileElementForCopy( void IrEmitterUnnested::EmitTileElementForFusion( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); @@ -2393,6 +2427,23 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { : llvm_ir::KernelMappingScheme::DimX; } + int GetNumberOfPartialResults() const { + if (IsRowReduction()) { + return 1; + } + int64 num_thread = mapping_scheme_->GetNumberOfThreadsForDimensionX(); + int64 tile_size = mapping_scheme_->GetTileSizeForDimensionX(); + CHECK_EQ(tile_size % num_thread, 0); + return tile_size / num_thread; + } + + int GetPartialResultIndex(int64 x_iter_num) const { + if (IsRowReduction()) { + return 0; + } + return x_iter_num; + } + private: AddressVector partial_result_addresses_; AddressVector reduction_input_addresses_; @@ -2452,10 +2503,11 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( llvm::AllocaInst* reduction_input_address = Alloca(element_type); reduction_input_addresses->push_back(reduction_input_address); + int num_partial_results = reduction_info->GetNumberOfPartialResults(); AddressVector* partial_result_addresses = reduction_info->GetMutablePartialResultAddresses(); llvm::AllocaInst* partial_result_address = - Alloca(element_type, /*ArraySize=*/nullptr, + Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results), "partial_reduction_result." + llvm::Twine(reduce_idx)); partial_result_addresses->push_back(partial_result_address); @@ -2478,7 +2530,9 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); } - Store(init_ir_value, partial_result_address); + for (int i = 0; i < num_partial_results; ++i) { + Store(init_ir_value, InBoundsGEP(partial_result_address, {b_.getInt32(i)})); + } } void IrEmitterUnnested::EmitPrologueForReduction( @@ -2516,10 +2570,14 @@ void IrEmitterUnnested::EmitPrologueForReduction( std::move(output_shape_index)); } - // Allocate stack storage to store the current output linear index and record - // the address of the storage. + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + + // Allocate stack storage to store the linear indices for the current output, + // and record the address of the storage. reduction_info->SetCurrentOutputLinearIndexAddress( - Alloca(reduction_info->GetIndexType())); + Alloca(reduction_info->GetIndexType(), + /*ArraySize=*/b_.getInt32(num_partial_results), + "current_output_linear_index_address")); if (!reduction_info->IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); @@ -2589,36 +2647,45 @@ void IrEmitterUnnested::EmitEpilogueForReduction( llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); } + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + // Emit an atomic operation that accumulates the partial reduction to the // output element. For row reduction, this is only for lane 0 due to the // if-statement emitted above. for (int i = 0; i != num_reduces; ++i) { - IrArray::Index element_index( - /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(), - "output_linear_addr"), - ShapeUtil::GetSubshape(unnested_hlo->shape(), - reduction_output_shape_indices[i]), - &b_); - llvm::Value* output_address = - GetIrArray(*unnested_hlo, *unnested_hlo, - reduction_output_shape_indices[i]) - .EmitArrayElementAddress(element_index, &b_, - "output_element_address"); - // Do not emit atomic operations if each element in the reduction result is - // computed by one block, that is the dimension being reduced has only one - // block. - const llvm_ir::KernelMappingScheme* mapping_scheme = - reduction_info->GetKernelMappingScheme(); - if (mapping_scheme->GetTileBlockSizeForDimension( - llvm_ir::KernelMappingScheme::DimZ) == 1 && - mapping_scheme->GetTileBlockSizeForDimension( - reduction_info->GetReducedDimensionEnum()) == 1) { - TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], {output_address, partial_result_addresses[i]}, - output_address)); - } else { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_result_addresses[i])); + for (int j = 0; j < num_partial_results; ++j) { + IrArray::Index element_index( + /*linear=*/Load( + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(j)}), + "output_linear_addr"), + ShapeUtil::GetSubshape(unnested_hlo->shape(), + reduction_output_shape_indices[i]), + &b_); + llvm::Value* output_address = + GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]) + .EmitArrayElementAddress(element_index, &b_, + "output_element_address"); + // Do not emit atomic operations if each element in the reduction result + // is computed by one block, that is the dimension being reduced has only + // one block. + const llvm_ir::KernelMappingScheme* mapping_scheme = + reduction_info->GetKernelMappingScheme(); + if (mapping_scheme->GetTileBlockSizeForDimension( + llvm_ir::KernelMappingScheme::DimZ) == 1 && + mapping_scheme->GetTileBlockSizeForDimension( + reduction_info->GetReducedDimensionEnum()) == 1) { + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], + {output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})}, + output_address)); + } else { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)}))); + } } } } @@ -2626,7 +2693,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion ? unnested_hlo->fused_expression_root() @@ -2639,8 +2706,11 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Record the linear address for the current reduction. const ReductionCodegenInfo* reduction_info = dynamic_cast(kernel_info); + int partial_result_index = reduction_info->IsRowReduction() ? 0 : x_iter_num; + Store(index[reduction_info->GetKeptDimensionEnum()], - reduction_info->GetCurrentOutputLinearIndexAddress()); + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(partial_result_index)})); if (!reduction_info->IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); llvm::AllocaInst* output_inbound_addr = @@ -2687,6 +2757,13 @@ void IrEmitterUnnested::EmitTileElementForReduction( reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( index, GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + if (num_partial_results > 1) { + // Clear the linear index field of the IrArray::Index to enable the use of + // GetElementPointer with array types. This enables the vectorization of + // the computation for different partial results. + input_index.ClearLinearIndex(); + } absl::Span partial_reduction_result_addresses = reduction_info->GetPartialResultAddresses(); absl::Span reduction_input_addresses = @@ -2699,10 +2776,12 @@ void IrEmitterUnnested::EmitTileElementForReduction( for (int i = 0; i != reducers.size(); ++i) { llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); Store(input_ir_value, reduction_input_addresses[i]); + llvm::Value* partial_result_address = + InBoundsGEP(partial_reduction_result_addresses[i], + {b_.getInt32(partial_result_index)}); TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], reduction_input_addresses[i]}, - partial_reduction_result_addresses[i])); + *reducers[i], {partial_result_address, reduction_input_addresses[i]}, + partial_result_address)); } // Emit code to generate the output for the non-reduction instructions in the @@ -2713,8 +2792,8 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Emits a kernel for the hlo instruction using the given tiling scheme. void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, + KernelCodegenInfo* kernel_info, + KernelSupportLibrary* ksl, llvm::Type* index_ty) { KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); @@ -2747,15 +2826,14 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, llvm::Value* num_tiles_in_block = Select(ICmpEQ(last_block_for_dim, block_id_for_dim), last_block_size_for_dim, block_size_for_dim); - - ksl.For(loop_name, - /*start=*/index_typed_constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, &b_); - emit_next_block_dim(tile_index); - }); + ksl->For(loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); } }; @@ -2810,7 +2888,8 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, // unnested_hlo: The unnested hlo instruction for which the kernel is generated. // Currently, these hlo instructions are supported: kLoop fusion, kCopy. // tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of -// other tensors with the same dimensions and need to be tiled and tranposed. +// other tensors with the same dimensions and are safe to be tranposed via +// the shared memory tranpose implementation. // mapping_scheme: The tiling scheme to use. // kernel_generator: Contains function objects for code generation, such as // element generator, block prologue and epilogue generators. @@ -2898,8 +2977,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( auto emit_tiled_elemental_code_with_bounds_check = [&](const IrArray::Index& index, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, &ksl, &b_, y, x, tile_height, tile_width, emit_elem_function); @@ -2912,10 +2990,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( const IrArray::Index input_tile_origin( Permute({0, 2, 1}, output_tile_origin.multidim())); - const IrArray::Index input_index = - input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread // copies it from input to tile. This is `__syncthreads` in CUDA. @@ -2925,9 +2999,10 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // Note that tile_width and tile_height are flipped here because we are // reading a transposed tile. emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + input_tile_origin, "input", output_tile_bounds[2], + output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { for (int64 id : tiled_param_ids) { IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; @@ -2947,18 +3022,15 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); kernel_info->SetTiledParamInfo(&tiled_param_info); - const IrArray::Index output_index = - output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Write to output[index] by emitting code like normal, except that values // for the tiled parameters are read from the shmem buffers. emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - kernel_generator.GetTileElementGenerator()(unnested_hlo, index, - kernel_info, y_loc, x_loc); + output_tile_origin, "output", output_tile_bounds[1], + output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + kernel_generator.GetTileElementGenerator()( + unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }); // If a tile block contains multiple tiles and shared memory buffers are @@ -2976,7 +3048,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( block_prologue_generator(unnested_hlo, kernel_info); } - EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty); const BlockEpilogueGenerator& block_epilogue_generator = kernel_generator.GetBlockEpilogueGenerator(); @@ -2989,7 +3061,10 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose // algorithm to improve the memory access patterns for the input parameters -// with a shape that is a 0-2-1 transpose of the output tensor shape. +// with a shape that is a 0-2-1 transpose of the output tensor shape. The caller +// is responsible for making sure that it is safe to apply the shared memory +// tranpose on the input parameters. +// // // For the purpose of tiling, the output tensors have a logical shape of three // components 0-2-1 while the relevant input parameters have a logical shape @@ -3022,17 +3097,19 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( element_generator = [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }; } else { DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - element_generator = [&](HloInstruction* hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); - }; + element_generator = + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc, + x_iter_num); + }; } KernelCodegenInfo kernel_info(&mapping_scheme); KernelCodeGenerator kernel_generator(std::move(element_generator)); @@ -3040,26 +3117,99 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( } namespace { -// Returns true to indicate it is safe to use the tile based shared memory -// transpose implementation to implement the kernel for the instruction. +// A recursive function to inspect the users of a parameter to determine +// whether it's safe for a parameter to participate in a shared-memory +// transpose. // -// An instruction is not safe for such an implementation if it can change the -// element order of a tensor without changing the dimension of the tensor, and -// the instruction has a corresponding elemental_ir_emitter. -bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { - auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { - HloOpcode opcode = instr->opcode(); - CHECK_NE(opcode, HloOpcode::kFusion); - return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); - }; +// Consider a fusion parameter P for which we might want to use a shmem +// transpose. If we do, we use a GPU thread block to preload a tile of P with +// indices [z, y..y+31, x..x+31] to compute an output tile with the same indices +// cooperatively, where z, y, x are the indices for the normalized input/output +// tensor (see the document for FindTranspose021 for the definition of +// normalized tensor for 0-2-1 transpose). This shmem transpose implementation +// requires that the computation of the output tile only read elements within +// the preload tile. If this is not true, we can't use a shmem transpose for P. +// +// If the computation of output element [z, y, x] only requires the element of +// P with the same indices, the shmem tranpose implementation can be applied +// to P safely. This is a sufficient but not necessary condition. We check all +// the transitive users of P to see if we can find a user that may cause an +// exception to the situation. If such a user is not found, we conclude that P +// is safe for shmem transpose. +// +// This is trivially true for elementwise operations and some "data-movement" +// ops like kTuple. However, it's not true for operations that can change the +// dimensions of the inputs (e.g. pad, slice) and bitcast operation. +// For example: +// +// fused_computation { +// param_0 = f32[64,64]{1,0} parameter(0) +// ROOT bitcast = f32[64,64]{0,1} bitcast(param_0) +// } +// The output element at logical address [0, 63] depends on the input element +// at logical address [63, 0], which would not be within the shared-memory +// block. +// +// TODO(bixia): In order to extend this for kInput fusion, that is reduction +// with tranpose, we only need to end the use-chain checking with the input of +// a reduce operations. In this case, the above description on "output" apply +// to the result of such a use-chain, which provides the input to the reduce +// operation. +bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { + if (hlo->IsElementwise()) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); + } + + switch (hlo->opcode()) { + // Non-elementwise instructions that don't cause the shmem transpose + // to be unsafe, including the instructions that don't currently fuse. + case HloOpcode::kGetDimensionSize: + // The result of the operation doesn't rely on the content of the + // tensor. As such, there is no need to further inspect its users. + return true; + case HloOpcode::kGetTupleElement: + case HloOpcode::kMap: + case HloOpcode::kParameter: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); - if (hlo->opcode() == HloOpcode::kFusion) { - return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), - is_safe_for_tile_based_transpose); + default: + return false; } +} - return is_safe_for_tile_based_transpose(hlo); +// Given a group of input parameters that are 0-2-1 tranpose of the outputs of +// a fusion kernel, returns the input parameters that are safe for the shared +// memory tranpose implementation. +// +// When a tile based shared memory transpose is used to implement an input with +// 0-2-1 transpose, we preload a tile of the input elements +// [z, y..y+31, x..x+31] to compute the output tile elements of the same +// indices. Preloading the input tile this way is only safe when the computation +// of the output tile elements do not need any input element outside the +// preloaded tile. We inspect all the transitive users of the input parameter +// up to the fusion root instruction to see if we can find any instruction +// that can make preloading the input tile unsafe. +std::vector FilterInputsForShmemTranspose(const HloInstruction* fusion, + std::vector input_ids) { + std::vector filtered_input_ids; + for (int64 i = 0; i < input_ids.size(); ++i) { + const HloInstruction* input = fusion->fused_parameter(input_ids[i]); + if (IsInstructionSafeForShmemTranspose(input)) { + filtered_input_ids.push_back(input_ids[i]); + } else { + VLOG(10) << "Input not safe for shmem transpose " << input->ToString() + << "\n"; + } + } + return filtered_input_ids; } + } // namespace bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { @@ -3106,8 +3256,11 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } - if (!IsInstructionSafeForTileBasedTranspose(hlo)) { - return false; + if (opcode == HloOpcode::kFusion) { + params_012 = FilterInputsForShmemTranspose(hlo, params_012); + if (params_012.empty()) { + return false; + } } // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the @@ -3250,11 +3403,101 @@ std::tuple GetReductionToVectorDimensions( return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); } +// Returns true if all the transitive users of hlo before hitting users in +// use_chain_endings are elementwise operations. +bool AreUsersElementwise(const HloInstruction* hlo, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return use_chain_endings.count(user) || + (user->IsElementwise() && + AreUsersElementwise(user, use_chain_endings)); + }); +} + +// Returns the number of fusion inputs that have the same dimension as the +// given shape, and involve in only elementwise operations. +int64 NumInputsInvolveInOnlyElementwiseOps( + const HloInstruction* unnested_hlo, const Shape& op_shape, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + const Shape& parameter_shape = parameter->shape(); + return ShapeUtil::SameDimensions(op_shape, parameter_shape) && + AreUsersElementwise(parameter, use_chain_endings); + }); +} + +// Returns the number of fusion inputs that have more elements than the given +// shape. +int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo, + const Shape& shape) { + int64 num_elements = ShapeUtil::ElementsIn(shape); + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + return ShapeUtil::ElementsIn(parameter->shape()) > num_elements; + }); +} + +// The benefit of unrolling a kInput fusion that is a column reduction comes +// from the vectorization of non-reduction fusion outputs and fusion inputs. +// On the other hand, unrolling can also introduce factors that can cause +// the kernel to run slower. This routine uses a simple heuristic to estimate +// the benefit as well as the overhead of unrolling in order to decide whether +// unrolling is beneficial for the given kInput fusion. +bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, + const Shape& input_shape, + int64 num_kept) { + // TODO(b/122468062): Need further investigate to see whether we can + // remove the constraint on IsPowerOfTwo. + if (!IsPowerOfTwo(static_cast(num_kept))) { + return false; + } + + if (unnested_hlo->opcode() == HloOpcode::kReduce) { + return true; + } + + CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion); + int64 can_be_vectorized = 0; + int64 cannot_be_vectorized = 0; + const HloInstruction* fused_root = unnested_hlo->fused_expression_root(); + ConstHloInstructionSet use_chain_endings; + if (fused_root->opcode() == HloOpcode::kReduce) { + use_chain_endings.insert(fused_root); + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple); + for (const HloInstruction* instr : fused_root->operands()) { + if (instr->opcode() == HloOpcode::kReduce) { + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + // Write of the non-reduction result can be vectorized. + can_be_vectorized++; + } + use_chain_endings.insert(instr); + } + } + // Fusion inputs that have the same dimension as the reduce input and + // only involve in elementwise operations can be vectorized. + can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps( + unnested_hlo, input_shape, use_chain_endings); + // Fusion inputs with more elements than the reduce op input must participate + // in non-elementwise operations and we assume that they are not vectorizable + // for the purpose of estimating the benefit of unrolling. If the kernel is + // unrolled even with such an assumption, and the accesses to those inputs + // turn out to be vectorizable, the compiler will still vectorize them. + cannot_be_vectorized += + NumInputsWithMoreElementsThan(unnested_hlo, input_shape); + return can_be_vectorized >= cannot_be_vectorized; +} + } // namespace std::tuple IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( - const HloInstruction* first_reduce) { + const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) { int64 depth = 1; int64 height = 1; int64 width = 1; @@ -3271,6 +3514,7 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( std::tie(num_reduced_major, num_kept, num_reduced_minor) = GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); CHECK_EQ(num_output_elems, num_kept); + bool dilated_x = true; if (num_kept == 1) { // Scalar reduction is a special row reduction with depth = height = 1. @@ -3285,13 +3529,21 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( is_row_reduction = false; // Column reduction without transpose doesn't require communication among // threads processing elements in the same tile. The current implementation - // only support the use of on hardware thread block to process one block of - // tiles in the KernelMappingScheme. We try to maximize the values of + // only support the use of one hardware thread block to process one block of + // tiles in the KernelMappingScheme. We try to use one thread to compute + // the partial results for two tensor elements and to maximize the values of // num_threads_x and tile_size_x to allow a bigger hardware thread block. int64 hw_threads_per_block_limit = ThreadsPerBlockLimit(ir_emitter_context_->device_description()); - tile_size_x = std::min(hw_threads_per_block_limit, num_kept); - num_threads_x = tile_size_x; + if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, + num_kept)) { + tile_size_x = std::min(2 * hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x / 2; + dilated_x = false; + } else { + tile_size_x = std::min(hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x; + } int64 kNumElementsPerPartialSum = 128; tile_size_y = kNumElementsPerPartialSum; } else { @@ -3320,6 +3572,7 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( llvm_ir::KernelMappingScheme mapping_scheme( dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, num_threads_x, &b_); + mapping_scheme.SetDilatedX(dilated_x); return std::make_tuple(mapping_scheme, is_row_reduction); } @@ -3368,14 +3621,15 @@ Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { bool is_row_reduction; llvm_ir::KernelMappingScheme mapping_scheme; std::tie(mapping_scheme, is_row_reduction) = - ComputeMappingSchemeAndReductionKind(first_reduce); + ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce); ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); KernelCodeGenerator kernel_generator( /*tile_element_generator=*/ [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { - EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc); + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc, + x_iter_num); }, /*block_prologue_generator=*/ [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index d217ee36cf6e9b5278024a2f78513232328e7538..21b842bb2cd63ac454f85556df20ae5877cecbe1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,7 +76,6 @@ class IrEmitterUnnested : public IrEmitter { void SetLaneId(llvm::Value* v) { lane_id_ = v; } void SetIndexType(llvm::Type* t) { index_ty_ = t; } void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { - CHECK_EQ(tiled_param_info_, nullptr); tiled_param_info_ = tiled_param_info; } @@ -89,7 +88,7 @@ class IrEmitterUnnested : public IrEmitter { } llvm::Type* GetIndexType() const { return index_ty_; } - private: + protected: llvm_ir::KernelMappingScheme* mapping_scheme_; llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; @@ -109,10 +108,12 @@ class IrEmitterUnnested : public IrEmitter { // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile. // kernel_info: Other information to support the kernel code generation. + // x_iter_num: When a thread process N elements in the X dimension, x_iter_num + // has a value of 0..N-1 to identify the element being process. using TileElementGenerator = std::function; + llvm::Value* x_loc, int64 x_iter_num)>; // KernelCodeGenerator records the code generator objects that generate code // for tile elements or tile block prologue/epilogue. @@ -216,9 +217,13 @@ class IrEmitterUnnested : public IrEmitter { Status EmitReductionToVector(HloInstruction* unnested_hlo); // Computes the KernelMappingScheme for the reduce HLO and indicates whether - // the reduction is a row reduction. + // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo + // and first_reduce are the same instruction. For a kInput fusion, + // unnested_hlo is the fusion instruction while first_reduce is the first + // reduce op. std::tuple - ComputeMappingSchemeAndReductionKind(const HloInstruction* first_reduce); + ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo, + const HloInstruction* first_reduce); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from @@ -243,26 +248,29 @@ class IrEmitterUnnested : public IrEmitter { const KernelCodeGenerator& kernel_generator, KernelCodegenInfo* kernel_info); void EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, llvm::Type* index_ty); + KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl, + llvm::Type* index_ty); // Emits code to process a tensor element in a tile for the given kCopy HLO // that performs a 0-2-1 transpose. void EmitTileElementForCopy(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Emits code to process a tensor element in a tile for the given kLoop fusion // HLO containing parameters that are 0-2-1 transpose of its outputs. void EmitTileElementForFusion(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Emits code to process a tensor element in a tile for the given input hlo // that is either a unnested kReduce or a kInput fusion. void EmitTileElementForReduction(HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Prepares for the code generation for a tile block of a reduction kernel. void EmitPrologueForReduction(HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index bd53b90b42d8e657a3ee58e7ca03fb60522aae28..eddaa877f28cb7bd32e924cd179814581eb97b12 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -125,6 +125,7 @@ static string GetSmName(std::pair compute_capability) { {{6, 2}, 62}, {{7, 0}, 70}, {{7, 2}, 72}, + {{7, 5}, 75}, }); int sm_version = 30; auto it = m->find(compute_capability); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 01fddcede64d1bb02ab89db5fc9524893c2d47a4..02e1207f377b8c28bf2566bee8cf3bcbc66794fb 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -67,7 +67,7 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, } int64 profit = 0; for (auto instr : instr2->operands()) { - if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + if (!IsProfitableOperand(instr) || !in_list.contains(instr)) { continue; } profit += ShapeUtil::ByteSizeOf(instr->shape()); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index d16c87ba5c63aa582753fe949e9e39ee2d8b81e5..40b87b16a195564c9b98497f79a70f1db0539d87 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -628,8 +628,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.1 = s32[1]{0} parameter(1) p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) - pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } fusion.2 { @@ -638,7 +637,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index cd369d55987b96eed2efb64ae0df6b3a76acb672..d1522280baf9a153d0731abfbc5a683df6b1cc53 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" @@ -152,6 +154,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -163,6 +166,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // We need a cost model for GPUs. Currently, do nothing. return false; }; + pipeline.AddPass(false); pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/true); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 4775baf44aecfe6adaf2bf0d2791595436635b16..1dedbd3befce6e2ceb06126d83a061207a90dd8f 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -25,7 +26,7 @@ namespace xla { namespace gpu { bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const { - return hlo_to_stream_number_.count(&hlo); + return hlo_to_stream_number_.contains(&hlo); } int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { @@ -98,10 +99,10 @@ int ComputeStreamToAssign( // greedy approach. First, we compute as forbidden_stream_numbers the // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign // `hlo` a different stream. - std::set forbidden_stream_numbers; + absl::flat_hash_set forbidden_stream_numbers; for (const auto* seen_gemm : seen_gemms) { int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm); - if (!forbidden_stream_numbers.count(stream_num) && + if (!forbidden_stream_numbers.contains(stream_num) && CanRunConcurrently(*seen_gemm, hlo, reachability)) { forbidden_stream_numbers.insert(stream_num); } @@ -109,7 +110,7 @@ int ComputeStreamToAssign( for (int stream_num = 0; stream_num < stream_assignment.StreamCount(); ++stream_num) { - if (!forbidden_stream_numbers.count(stream_num)) { + if (!forbidden_stream_numbers.contains(stream_num)) { return stream_num; } } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index a302b582ede3723acd118d2e4a4bb3efdf7a4d0b..869724db601b2d5e4ed6d3c7bf3e10a748433146 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -65,7 +65,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -91,7 +91,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -118,7 +118,7 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -152,7 +152,7 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -187,13 +187,13 @@ TEST_F(GpuKernelTilingTest, CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } -TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { +TEST_F(GpuKernelTilingTest, TransposedInputWithUserReverseNotTiled) { const char *const kHloString = R"( HloModule FusionTransposeWithReverseNotTiled fused_computation.1 { @@ -214,12 +214,203 @@ TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, TransposedInputWithUserBitcastNotTiled) { + const char *const kHloString = R"( + HloModule TransposedInputWithUserBitcast + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) { + const char *const kHloString = R"( + HloModule TwoTransposedInputs + + fused_computation { + param_0 = f32[64,64]{1,0} parameter(0) + param_1 = f32[64,64]{1,0} parameter(1) + bitcast = f32[64,64]{0,1} bitcast(param_0) + copy = f32[64,64]{0,1} copy(param_1) + ROOT tuple = (f32[64,64]{0,1}, f32[64,64]{0,1}) tuple(bitcast, copy) + } + + ENTRY kernel_entry { + parameter.0 = f32[64,64]{1,0} parameter(0) + parameter.1 = f32[64,64]{1,0} parameter(1) + ROOT fusion = (f32[64,64]{0,1}, f32[64,64]{0,1}) + fusion(parameter.0, parameter.1), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionWithPowerOf2OutputElementsUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2 + + reduction { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY kernel_entry { + constant0 = f32[] constant(0) + arg1 = f16[1024,512]{1,0} parameter(0) + arg1_conv = f32[1024,512]{1,0} convert(arg1) + ROOT reduce = f32[512]{0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction + })"; + + // Check that two calls to llvm.nvvm.atomic are generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, + ColumnReductionWithInputLargerThenReduceInputNotUnrolled) { + const char *const kHloString = R"( + HloModule larger_than_reduce_input_parameter + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1027,513]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1027,513]{1,0} convert(arg.2) + slice2 = f32[1024,512]{1,0} slice(arg2.conv), slice={[2:1026], [1:513]} + add2 = f32[1024,512]{1,0} add(arg1.conv, slice2) + ROOT reduce = f32[512]{0} reduce(add2, constant0), dimensions={0}, + to_apply=reduction22 + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1027,513]{1,0} parameter(1) + ROOT fusion = f32[512]{0} fusion(arg1, arg2), kind=kInput, + calls=fused_computation + })"; + + // Check that one call to llvm.nvvm.atomic is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2_mof + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1024,512]{1,0} convert(arg.2) + reduce1 = f32[512]{0} reduce(arg1.conv, constant0), dimensions={0}, + to_apply=reduction22 + reduce2 = f32[512]{0} reduce(arg2.conv, constant0), dimensions={0}, + to_apply=reduction22 + add = f32[1024,512]{1,0} add(arg1.conv, arg2.conv) + ROOT tuple = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + tuple(reduce1, reduce2, add) + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1024,512]{1,0} parameter(1) + ROOT fusion = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + fusion(arg1, arg2), kind=kInput, calls=fused_computation + })"; + + // Check that four calls to llvm.nvvm.atomic are generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 6b2d76764a077dc6cfa3f9ddc6e525ab330323be..25bad67bab9375559c431466571c62acd0452b01 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -14,17 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace xla { namespace gpu { void ThunkSchedule::AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk) { - if (hlo_to_thunk.count(&operand)) { + const absl::flat_hash_map& hlo_to_thunk) { + if (hlo_to_thunk.contains(&operand)) { // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency // list if `operand` is assigned to a different stream. As an optimization, // we skip `operand`'s operands because `operand` depends on them already. @@ -48,14 +50,14 @@ ThunkSchedule::ThunkSchedule( const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { - std::unordered_map hlo_to_thunk; + absl::flat_hash_map hlo_to_thunk; for (const auto& thunk : *thunks_) { InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } for (HloInstruction* hlo : hlo_total_order) { - if (hlo_to_thunk.count(hlo)) { - thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); + if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) { + thunk_total_order_.push_back(*thunk); } } @@ -106,7 +108,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { // redundant dependency edge. Array2D last_dependency(stream_count, stream_count, -1); for (const Thunk* dst : thunk_total_order_) { - if (!depends_on_.count(dst)) { + if (!depends_on_.contains(dst)) { continue; } @@ -134,7 +136,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { const std::list& ThunkSchedule::DependsOn( const Thunk* thunk) const { - if (depends_on_.count(thunk)) { + if (depends_on_.contains(thunk)) { return FindOrDie(depends_on_, thunk); } else { return empty_thunk_list_; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index 43b628a1baf0e79a3197f3cfad3547991642eaed..549378debd52417252724a5d8a6f4d24f2ad0369 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -54,7 +56,9 @@ class ThunkSchedule { // Thunks that `thunk` depends on. const std::list& DependsOn(const Thunk* thunk) const; // Whether `thunk` is depended by another thunk. - bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); } + bool Depended(const Thunk* thunk) const { + return depended_by_.contains(thunk); + } // Delegates to StreamAssignment. int StreamCount() const { return stream_assignment_->StreamCount(); } @@ -75,13 +79,13 @@ class ThunkSchedule { // thunk.hlo_instruction(). void AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk); + const absl::flat_hash_map& hlo_to_thunk); std::unique_ptr thunks_; std::vector thunk_total_order_; - std::unordered_map> depends_on_; - std::set depended_by_; + absl::flat_hash_map> depends_on_; + absl::flat_hash_set depended_by_; std::list empty_thunk_list_; std::unique_ptr stream_assignment_; diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 9220865867b770eebfb1ada8f31a5d24693a4b8d..4fca981c6a59cdb91a997e6a887fd26472c1a10a 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -199,7 +199,7 @@ Status HeapSimulator::RunComputation( // If the buffer has no users and isn't an entry parameter or output, it // must be a dead value. - if (live_buffers.count(buffer) == 0) { + if (!live_buffers.contains(buffer)) { dead_buffers_to_free.push_back(buffer); } } @@ -225,10 +225,10 @@ Status HeapSimulator::RunComputation( } } // Sort to get a deterministic iteration order. - std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(operand_buffers_to_free, + [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. This must @@ -253,7 +253,7 @@ Status HeapSimulator::RunComputation( bool shared = false; if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { - if (reused_buffers.count(operand_buffer) != 0) { + if (reused_buffers.contains(operand_buffer)) { continue; } if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && @@ -335,10 +335,9 @@ Status HeapSimulator::RunComputation( to_free.push_back(buffer); } - std::sort(to_free.begin(), to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); @@ -374,15 +373,15 @@ bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { return true; } return options_.buffers_to_assign != nullptr && - options_.buffers_to_assign->count(buffer) == 0; + !options_.buffers_to_assign->contains(buffer); } // Alloc always calls the underlying heap algorithm. void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "Alloc called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Alloc called on freed buffer: " << *buffer; allocated_buffers_.insert(buffer); @@ -411,9 +410,9 @@ void HeapSimulator::Free(const BufferValue* buffer, buffer = group->canonical; } - CHECK(allocated_buffers_.count(buffer) > 0) + CHECK(allocated_buffers_.contains(buffer)) << "Free called on non-allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Free called on freed buffer: " << *buffer; freed_buffers_.insert(buffer); @@ -433,11 +432,11 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "ShareBuffer called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "ShareBuffer called on freed buffer: " << *buffer; - CHECK(freed_buffers_.count(shared) == 0) + CHECK(!freed_buffers_.contains(shared)) << "ShareBuffer called on freed shared buffer: " << *shared; const BufferValue* canonical = nullptr; @@ -452,7 +451,7 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, } else { // The 'shared' buffer doesn't have a group; it must be the canonical. Add // both 'buffer' and 'shared' to a new group. - CHECK(allocated_buffers_.count(shared) > 0) + CHECK(allocated_buffers_.contains(shared)) << "ShareBuffer called on non-allocated shared buffer: " << *shared; auto group = std::make_shared(); canonical = shared; @@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { } // Call ops in the run sorted by decreasing size, breaking ties by buffer id. - std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) { + absl::c_sort(run_, [](const Op& a, const Op& b) { if (a.size != b.size) { return a.size > b.size; } @@ -866,23 +865,23 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { for (auto& entry : buffer_intervals_) { sorted_buffer_intervals.push_back(entry.second); } - std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), - [](const BufferInterval& x, const BufferInterval& y) { - if (x.size != y.size) { - return x.size > y.size; - } - if (x.end - x.start != y.end - y.start) { - return x.end - x.start > y.end - y.start; - } - return x.buffer->id() < y.buffer->id(); - }); + absl::c_sort(sorted_buffer_intervals, + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); for (auto& buffer_interval : sorted_buffer_intervals) { auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( buffer_interval.start, buffer_interval.end); - std::sort( - chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + absl::c_sort( + chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); // Find the minimum free chunk that can hold this buffer. diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 9b50f1ca5b5365463f32106fc005ef2c63f2e37a..263b42a29dbb0dbc0fb6eca7968674ff242f45ed 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -229,6 +229,18 @@ message HloScheduleProto { } message HloInputOutputAliasProto { + enum Kind { + // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + // behavior and missing has_*() APIs. + UNDEFINED_ALIAS = 0; + // An alias setup by the user as must alias. A use setting USER_ALIAS is + // expecting the designed output to be dropped over the given input + // parameter number+index. + USER_ALIAS = 1; + // An alias setup by the compiler as part of its optimizations. + SYSTEM_ALIAS = 2; + } + // The following proto describes a pair of aliased an input // (described by parameter number and a ShapeIndex of the parameter) // and an output (described by a ShapeIndex of the root @@ -249,6 +261,8 @@ message HloInputOutputAliasProto { int64 parameter_number = 2; // ShapeIndex of the parameter instruction. repeated int64 parameter_shape_index = 3; + // The kind of alias to be setup. + Kind kind = 4; } repeated AliasEntryProto entries = 1; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 68094d8907f30202533672376b2199e7b77dc806..e511f1951c5dd07ebb64fa38fd5b7f6a0e87b429 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -117,7 +117,7 @@ class BufferValueMap { for (const auto& pair : buffers_) { buffer_numbers.push_back(pair.first); } - std::sort(buffer_numbers.begin(), buffer_numbers.end()); + absl::c_sort(buffer_numbers); return buffer_numbers; } @@ -176,13 +176,12 @@ class BufferValueMap { const HloValue& value, std::vector* aliased_buffers) { // Get parameter value from an aliased_input object. const auto get_parameter_value = - [this](const std::pair& aliased_input) + [this](const HloInputOutputAliasConfig::Alias& aliased_input) -> const HloValue& { - int64 param_number = aliased_input.first; - const ShapeIndex& param_index = aliased_input.second; return dataflow_.GetUniqueValueAt( - module_->entry_computation()->parameter_instruction(param_number), - param_index); + module_->entry_computation()->parameter_instruction( + aliased_input.parameter_number), + aliased_input.parameter_index); }; // If the value shows up in a root instruction, alias it with parameter @@ -319,7 +318,7 @@ class BufferValueMap { ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. - std::sort(aliased_buffers.begin(), aliased_buffers.end()); + absl::c_sort(aliased_buffers); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); @@ -367,7 +366,7 @@ std::vector HloAliasAnalysis::ComputeBuffersAt( } // Sort and uniquify vector before returning. - std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan); + absl::c_sort(buffers, HloBuffer::IdLessThan); buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end()); return buffers; @@ -430,8 +429,7 @@ Status HloAliasAnalysis::Verify() const { for (const auto& pair : value_to_buffer_) { const HloValue* value = pair.first; const HloBuffer& buffer = *pair.second; - TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(), - value) != buffer.values().end()); + TF_RET_CHECK(absl::c_linear_search(buffer.values(), value)); } for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) { @@ -515,7 +513,7 @@ StatusOr> HloAliasAnalysis::Run( auto& value_set = buffer_map.GetValuesInBuffer(buffer_number); std::vector sorted_values(value_set.begin(), value_set.end()); - std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan); + absl::c_sort(sorted_values, HloValue::IdLessThan); alias_analysis->buffers_.emplace_back(next_id++, sorted_values); for (const HloValue* value : sorted_values) { alias_analysis->value_to_buffer_[value] = @@ -547,16 +545,15 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // tie-break using value ID. The tie-break is necessary because we need a // strict weak order for std::sort. std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [&ordering](const HloValue* a, const HloValue* b) { - if (ordering.IsDefinedBefore(*a, *b)) { - return true; - } else if (ordering.IsDefinedBefore(*b, *a)) { - return false; - } else { - return a->id() < b->id(); - } - }); + absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) { + if (ordering.IsDefinedBefore(*a, *b)) { + return true; + } else if (ordering.IsDefinedBefore(*b, *a)) { + return false; + } else { + return a->id() < b->id(); + } + }); // Walk through the ordered vector of values. First verify that the values // are totally ordered with respect to 'ordering', then check that no diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 7e6150e94153cd15463725e862ce1b8593f2c991..b6dbf07959c541bceaa8eda5a0101503970ee832 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -238,13 +238,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -279,13 +282,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -365,9 +371,11 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 9c3aa0e64d119c2560f4955d0bcb492519fa52a2..32e48651b30bace4723169935d1f10dd7d7bfec3 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -49,7 +49,7 @@ std::vector HloBuffer::ComputePositions() const { value->positions().end()); } // Remove duplicates and sort positions. - std::sort(positions.begin(), positions.end()); + absl::c_sort(positions); positions.erase(std::unique(positions.begin(), positions.end()), positions.end()); return positions; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 52ca67afb8eb270c490566ed51514b0b0f499b42..f9b64d12ae83139efa21ca67e565908bd78f9780 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -207,14 +207,14 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(instruction->user_count() == 0); TF_RET_CHECK(IsRemovable(instruction)) << "Cannot remove instruction: " << instruction->ToString(); - std::unordered_set removed; + absl::flat_hash_set removed; std::queue worklist; worklist.push(instruction); while (!worklist.empty()) { HloInstruction* item = worklist.front(); worklist.pop(); - if (removed.count(item) != 0 || item->user_count() != 0 || + if (removed.contains(item) || item->user_count() != 0 || item == root_instruction() || !IsRemovable(item) || (item->HasSideEffect() && item != instruction)) { continue; @@ -531,11 +531,10 @@ HloComputation::CreateFromProto( HloInstruction* root = instruction_map.at(proto.root_id()); // Sort the instructions in the proto id's order. - std::sort(instructions.begin(), instructions.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(instructions, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); TF_RETURN_IF_ERROR([&]() -> Status { std::vector parameters_seen(parameter_count); @@ -694,13 +693,14 @@ bool HloComputation::operator==(const HloComputation& other) const { if (this == &other) { return true; } - std::set> visited; + absl::flat_hash_set> + visited; std::function eq = [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { // If are visited but not identical, the recursion should have // been aborted. So, if are visited at this point, they must be // identical. - if (visited.count(std::make_pair(a, b)) > 0) { + if (visited.contains(std::make_pair(a, b))) { return true; } visited.emplace(a, b); @@ -799,17 +799,16 @@ Status HloComputation::AcceptOrdered( absl::Span order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { - TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) - << root->ToString(); + TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString(); } TF_RET_CHECK(order.size() == instruction_count()); - std::unordered_set visited; + absl::flat_hash_set visited; for (const HloInstruction* instruction : order) { VLOG(3) << "Visiting ordered: " << instruction->ToString(); - TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) + TF_RET_CHECK(instruction_iterators_.contains(instruction)) << "Instruction " << instruction->name() << " is not in computation " << name(); - TF_RET_CHECK(visited.count(instruction) == 0) + TF_RET_CHECK(!visited.contains(instruction)) << "Instruction " << instruction->name() << " appears more than once in order"; HloInstruction* mutable_instruction = @@ -847,7 +846,7 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - context, suffix); + /*extra_parameters=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( @@ -856,7 +855,8 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( std::unordered_map> replacements; replacements.emplace(std::move(r1)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( @@ -867,7 +867,8 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( @@ -880,12 +881,14 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); replacements.emplace(std::move(r3)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, + absl::Span extra_parameters, HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { @@ -951,6 +954,12 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; + // First add the extra parameters to 'instructions'. + for (const auto& instr : extra_parameters) { + CHECK_EQ(instr->opcode(), HloOpcode::kParameter) + << "Only parameter instructions are allowed in 'extra_parameters'"; + instructions.emplace_back(instr->Clone()); + } for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index a0ccbc583f8c409f29d31756fcc1fa1b4af7dc35..e6a1eb89cfdb474f79c184ea0eb77dba8ccd5f03 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -323,11 +323,15 @@ class HloComputation { // that's not already in the computation, it's cloned and added to the new // computation. // + // 'extra_parameters' allows to specify additional parameters that should be + // added to the computation. + // // All relevant instructions are cloned, *including* unique_ptr in the // `replacements` map. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, + absl::Span extra_parameters = {}, HloCloneContext* context = nullptr, const string& suffix = "clone"); // Convenience overloads for CloneWithReplacements. You want to do diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 0361c87428f6e4c031d95492a5bc782ad388e5b5..251c7bbec418d8c3e8b27277160e608840726996 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include #include +#include +#include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -226,7 +230,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { : computation_(computation) {} Status DefaultAction(HloInstruction* hlo_instruction) override { - EXPECT_EQ(0, visited_set_.count(hlo_instruction)); + EXPECT_FALSE(visited_set_.contains(hlo_instruction)); visited_set_.insert(hlo_instruction); last_visited_ = hlo_instruction; return Status::OK(); @@ -239,7 +243,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { } HloComputation* computation_; - std::set visited_set_; + absl::flat_hash_set visited_set_; int64 finish_visit_calls_ = 0; HloInstruction* last_visited_ = nullptr; }; @@ -491,6 +495,41 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } +TEST_F(HloComputationTest, CloneWithReplacements) { + auto builder = HloComputation::Builder(TestName()); + Shape r0s64 = ShapeUtil::MakeShape(S64, {}); + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + Shape r0u32 = ShapeUtil::MakeShape(U32, {}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1")); + auto lt = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1)); + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/lt)); + std::unordered_map> + replacements; + replacements.emplace(param2, + HloInstruction::CreateParameter(2, r0s32, "p.1")); + auto param3 = HloInstruction::CreateParameter(3, r0u32, "p.2"); + std::vector extra_parameters{param3.get()}; + auto clone = computation->CloneWithReplacements(std::move(replacements), + extra_parameters); + ASSERT_EQ(clone->num_parameters(), 4); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(0)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(1)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(2)->shape(), r0s32)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(3)->shape(), r0u32)); +} + TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index c58d00ff5084d060498f2d6fbbfa8e12207b810e..e7ed858e8c5af83d08863d64a0aba162c75ed5cb 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -35,6 +35,34 @@ limitations under the License. namespace xla { +// Checks whether instr is or transitively contains an instruction that we +// shouldn't fold. +// +// Specifically, we don't fold kRng or kAfterAll instructions: +// +// - kRng is already marked as side-effecting and so is skipped elsewhere, but +// we check for it here. Even kRng weren't side-effecting and took an +// explicit seed, we *still* wouldn't want to constant-fold it, because the +// evaluator's handling of rng is not guaranteed to be identical to any +// particular backend's rng. +// +// - kAfterAll needs to be skipped because a kAfterAll op with no args can +// currently materialize a token "out of thin air". TODO(b/110532604): +// Remove this check once AfterAll requires at least one operand, in which +// case constant folding will be impossible. +static bool IsOrContainsIllegalInstr(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kAfterAll || + instr->opcode() == HloOpcode::kRng) { + return true; + } + for (const HloComputation* c : instr->called_computations()) { + if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) { + return true; + } + } + return false; +} + StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may @@ -52,25 +80,24 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Tuple, AfterAll operation. - // Tuple constants are not directly supported by any backends, hence - // folding Tuple is not useful and would in fact be expanded back into - // kTuple by Algebraic Simplifier. - // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one - // operand in which case constant folding will be impossible and this - // special case is not necessary. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kAfterAll) { - continue; - } // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; } + // Don't fold Constant, Parameter, and Tuple instructions. Tuple + // constants are not directly supported by any backends, hence folding + // Tuple is not useful and would in fact be expanded back into kTuple by + // Algebraic Simplifier. + // + // (We do allow folding subcomputations that contain these instructions.) + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kTuple) { + continue; + } + // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. @@ -79,6 +106,18 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Check for instructions that we can't fold even if they appear inside of + // a subcomputation (e.g. a kCall). + if (IsOrContainsIllegalInstr(instruction)) { + continue; + } + + // Don't constant-fold side-effecting instructions or instructions which + // contain side-effecting instructions. + if (instruction->HasSideEffect()) { + continue; + } + // Don't constant fold unless it's a net positive or the output is small. if (instruction->shape().IsArray()) { int64 elements_in_removed_operands = 0; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 92b748d813c3efef83ef0155f1d5d3c637ce2c57..4bdc980c9ac4fb79cde0242f407aea7057474b27 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -268,5 +268,51 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { GmockMatch(m::Pad(m::Constant(), m::Constant()))); } +TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) { + const char* const kModuleStr = R"( + HloModule test + + Fn { + tok = token[] after-all() + ROOT root = f32[10] iota(), iota_dimension=0 + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(HloConstantFoldingTest, + DontFoldSubcomputationTransitivelyContainingRng) { + const char* const kModuleStr = R"( + HloModule test + + InnerFn { + c0 = f32[] constant(0) + c1 = f32[] constant(1) + ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform + } + + Fn { + ROOT fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 1678fba1728e161b1f448079f366e8a68db03414..bb5d21c654c73da257d53e4f8486b2e83019b534 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -105,12 +106,26 @@ StatusOr MakeDynamicSliceHlo( absl::Span slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. + auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_slice_shape, ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), slice_sizes)); + operand->shape(), scalar_start_indices_shapes, slice_sizes)); return computation->AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, operand, start_indices, slice_sizes)); + dynamic_slice_shape, operand, scalar_start_indices, slice_sizes)); } StatusOr MakeDynamicUpdateSliceHlo( @@ -119,12 +134,26 @@ StatusOr MakeDynamicUpdateSliceHlo( HloComputation* computation = operand->parent(); CHECK_EQ(computation, update->parent()); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. + auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_update_slice_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), scalar_start_indices_shapes)); return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dynamic_update_slice_shape, operand, update, start_indices)); + dynamic_update_slice_shape, operand, update, scalar_start_indices)); } StatusOr MakeBroadcastHlo( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index aaa9ec60eb3c4e0159ed40b37d772e0973d306ec..3715e12b4e2baf7bc2149237457c16c3919c5083 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -56,9 +56,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } @@ -77,10 +77,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, - {LiteralUtil::CreateR3( - {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, + {{-1, -2}, {-3, -4}, {-5, -6}}})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); @@ -101,8 +100,7 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 10}})); } @@ -121,8 +119,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } @@ -141,7 +138,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } @@ -160,8 +157,8 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } @@ -180,9 +177,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } @@ -202,7 +199,7 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } @@ -220,9 +217,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR0(0.0f)})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0.0f)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 1924204df044faa6d55f11c1180e2ecc1f0e9e64..3144a84805454488f417391f40ed6b9e9facc752 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -107,7 +107,7 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( return false; } } - if (!visited.count(user)) { + if (!visited.contains(user)) { stack.push_back(user); } } @@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi( input_value_ids.push_back(value->id()); } } - std::sort(input_value_ids.begin(), input_value_ids.end()); + absl::c_sort(input_value_ids); input_value_ids.erase( std::unique(input_value_ids.begin(), input_value_ids.end()), input_value_ids.end()); @@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi( if (current_value_defined_here) { VLOG(5) << "current_value_defined_here: " << current_value->ToString(); CHECK(current_value->is_phi()); - auto it = std::find(input_value_ids.begin(), input_value_ids.end(), - current_value->id()); + auto it = absl::c_find(input_value_ids, current_value->id()); if (it != input_value_ids.end()) { input_value_ids.erase(it); } @@ -921,8 +920,7 @@ StatusOr> HloDataflowAnalysis::Run( for (auto& pair : dataflow_analysis->values_) { dataflow_analysis->values_vector_.push_back(&pair.second); } - std::sort(dataflow_analysis->values_vector_.begin(), - dataflow_analysis->values_vector_.end(), HloValue::IdLessThan); + absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan); TF_DCHECK_OK(dataflow_analysis->Verify()); @@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const { for (const HloValue* value : values()) { for (const HloPosition& position : value->positions()) { const HloValueSet& value_set = GetValueSet(position); - TF_RET_CHECK(std::find(value_set.values().begin(), - value_set.values().end(), - value) != value_set.values().end()) + TF_RET_CHECK(absl::c_linear_search(value_set.values(), value)) << "Value set at position " << position << " does not contain value " << value->ToShortString(); } @@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const { const HloValueSet& value_set = pair.second; const HloPosition position{instruction, index}; for (const HloValue* value : value_set.values()) { - TF_RET_CHECK(std::find(value->positions().begin(), - value->positions().end(), - position) != value->positions().end()) + TF_RET_CHECK(absl::c_linear_search(value->positions(), position)) << "Value set at position " << position << " unexpectedly contains value " << value->ToShortString(); } @@ -1041,11 +1035,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } @@ -1100,16 +1093,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // *) The root instruction of the called computation is element-wise on // 'operand'. const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + absl::c_find_if(uses, [user](const HloUse& use) { return use.instruction == user; }) != uses.end(); auto* callee_root = user->to_apply()->root_instruction(); const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); + absl::c_find_if(uses, [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 1d165ac35f766af857dc0984d3f7012ad945b3c2..888886865b9cd7d09af6d3b5f016b60ccef5facd 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1970,12 +1970,13 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2012,12 +2013,13 @@ TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2150,17 +2152,17 @@ TEST_F(CanShareOperandBufferWithUserTest, auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "param0")); - auto index = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0}))); - auto ds = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, param, {zero, zero}, {1, 2, 2})); - auto dus = builder.AddInstruction( - HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index)); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, param, ds, {zero, zero})); BuildModule(builder.Build()); auto fusion = computation_->CreateFusionInstruction( - {dus, ds, index}, HloInstruction::FusionKind::kLoop); + {dus, ds, zero}, HloInstruction::FusionKind::kLoop); RunAnalysis(); EXPECT_TRUE( @@ -2219,12 +2221,13 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2259,12 +2262,13 @@ TEST_F(CanShareOperandBufferWithUserTest, // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape_bf16, convert1, update, starts)); + data_shape_bf16, convert1, update, + std::initializer_list({starts}))); auto convert2 = builder.AddInstruction( HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); @@ -2290,10 +2294,13 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto start0 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "start0")); + auto start1 = builder.AddInstruction( + HloInstruction::CreateParameter(3, starts_shape, "start1")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); + data_shape, data, update, {start0, start1})); BuildModuleAndRunAnalysis(builder.Build()); @@ -2304,7 +2311,9 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { EXPECT_FALSE( dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); + dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {})); } TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 7d35e251ca21951036336ff1a1eb4aabc87bc5ca..a5a11f09cf4f857b992e5ede3a9dbc5a937ce722 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -65,7 +66,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Now DCE HloComputations. First, collect the computations that are // referenced by some remaining instruction. - std::unordered_set live_computations; + absl::flat_hash_set live_computations; if (HloComputation* entry_computation = module->entry_computation()) { live_computations.insert(entry_computation); } @@ -79,7 +80,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Remove dead computations. for (auto* computation : module->MakeComputationPostOrder()) { - if (live_computations.count(computation) == 0) { + if (!live_computations.contains(computation)) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 1fa4259a3e42286cbc911907eea563e6ca6f8611..b5d72b386f89568cc3066b2e497be98428d1ed0c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index c6d02f9f67bb599e496d20fc2acf2e627ed54438..7cdb7f6bdf26241cda4fabbb5ccaf6e6f7de39ce 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -230,10 +230,10 @@ HloDomainMap::MakeNonDomainInstructions( } } // sort instructions according to instructions_order - std::sort(instructions.begin(), instructions.end(), - [&instructions_order](HloInstruction* a, HloInstruction* b) { - return instructions_order.at(a) < instructions_order.at(b); - }); + absl::c_sort(instructions, + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); + }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index cdb2df24db8b6bafa4c7e98c635435750cdf42ad..b8d95fb24479b375506cfe9a788540ebefafb8c1 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -138,6 +138,11 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, } // namespace +// Note that unsupported types by the typed visitor does not necessarily imply +// the non-typed HloEvaluator (parent evaluator) would not support them either +// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent +// type-agnostic evaluator will be able to accept Tuple primitive type, whereas +// HloEvaluatorTypedVisitor cannot. HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = @@ -198,99 +203,47 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) }); } -template -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); - - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); - } - - TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); - - return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); - } - return Evaluate(module, arg_literal_ptrs); -} - -template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - absl::Span arg_literals) { + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); + if (arg_literals.size() != computation.num_parameters()) { + return InvalidArgument( + "Expected %d argument%s, but got %d.", computation.num_parameters(), + computation.num_parameters() == 1 ? "" : "s", arg_literals.size()); } - - TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); + for (int64 i = 0; i < arg_literals.size(); ++i) { + const auto& computation_shape = + computation.parameter_instruction(i)->shape(); + const auto& arg_shape = arg_literals[i]->shape(); + if (!ShapeUtil::Equal(computation_shape, arg_shape)) { + return InvalidArgument( + "Shape mismatch at parameter %d. Computation expected %s, but arg " + "was %s.", + i, ShapeUtil::HumanStringWithLayout(computation_shape), + ShapeUtil::HumanString(arg_shape)); + } } - return Evaluate(computation, arg_literal_ptrs); -} - -template -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); evaluated_.clear(); arg_literals_.clear(); for (const auto& literal_ptr : arg_literals) { arg_literals_.push_back(&*literal_ptr); } - - // Evaluate operands of Parameter type against the input literals which - // caches the evaluated literal results. - for (const auto operand : instruction->operands()) { - if (operand->opcode() == HloOpcode::kParameter) { - const Literal* input_literal = arg_literals_[operand->parameter_number()]; - VLOG(2) << "Parameter operand evaluated to: " - << input_literal->ToString(); - TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - - evaluated_[operand] = input_literal->Clone(); - } + if (computation.parent()->config().seed()) { + seed_ = computation.parent()->config().seed(); + } else { + std::random_device rd; + seed_ = rd(); } - TF_RETURN_IF_ERROR(Preprocess(instruction)); - TF_RETURN_IF_ERROR(instruction->Visit(this)); - TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).Clone(); -} + engine_ = std::minstd_rand0(seed_); -template <> -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal : arg_literals) { - arg_literal_ptrs.push_back(&literal); - } - return Evaluate(instruction, arg_literal_ptrs); + TF_RETURN_IF_ERROR(computation.Accept(this)); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); } StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { @@ -408,6 +361,31 @@ Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status HloEvaluator::HandleGetDimensionSize( + HloInstruction* get_dimension_size) { + HloInstruction* operand = get_dimension_size->mutable_operand(0); + int64 dim = get_dimension_size->dimension(); + if (dynamic_dimension_inference_ == nullptr) { + return InvalidArgument( + "Evaluator cannot evaluate get_dimension_size without " + "set_dynamic_dimension_inference."); + } + HloInstruction* dynamic_size = + dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + evaluated_[get_dimension_size] = + GetEvaluatedLiteralFor(dynamic_size).Clone(); + return Status::OK(); + } + + const Shape& shape = get_dimension_size->operand(0)->shape(); + Literal output(ShapeUtil::MakeShape(U32, {})); + output.PopulateWithValue( + static_cast(shape.dimensions(get_dimension_size->dimension()))); + evaluated_[get_dimension_size] = std::move(output); + return Status::OK(); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -1127,9 +1105,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - Literal result = - embedded_evaluator.Evaluate(*computation, arg_literals) - .ConsumeValueOrDie(); + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); evaluated_[call] = std::move(result); return Status::OK(); @@ -1145,7 +1124,9 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { fusion->fused_instructions_computation()->Clone( /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { - LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + if (!LayoutUtil::HasLayout(instruction->shape())) { + LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + } } auto readded_computation = empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation)); @@ -1159,9 +1140,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result = - embedded_evaluator - .Evaluate(*readded_computation, arg_literals) + embedded_evaluator.Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); evaluated_[fusion] = std::move(result); @@ -1179,16 +1161,16 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result; if (pred.Get({})) { - result = embedded_evaluator - .Evaluate(*true_computation, - {&true_computation_arg}) - .ConsumeValueOrDie(); + result = + embedded_evaluator.Evaluate(*true_computation, {&true_computation_arg}) + .ConsumeValueOrDie(); } else { result = embedded_evaluator - .Evaluate(*false_computation, - {&false_computation_arg}) + .Evaluate(*false_computation, {&false_computation_arg}) .ConsumeValueOrDie(); } @@ -1235,18 +1217,21 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); + cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_); HloEvaluator loop_body_evaluator(max_loop_iterations_); + loop_body_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } TF_ASSIGN_OR_RETURN(auto cond_val, - cond_evaluator.Evaluate(*cond_comp, {&lcv})); + cond_evaluator.Evaluate(*cond_comp, {&lcv})); keep_going = cond_val.GetFirstElement(); if (keep_going) { - TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {&lcv})); + TF_ASSIGN_OR_RETURN(auto body_val, + loop_body_evaluator.Evaluate(*body_comp, {&lcv})); VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); @@ -1297,8 +1282,7 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, // Extract a slice from the keys and values literals that correspond to // exactly the row in dimension 'sort_dim'. std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); + absl::c_for_each(limit_indices, [](int64& index) { ++index; }); limit_indices[sort_dim] = sort_dim_elements; TF_ASSIGN_OR_RETURN(auto keys_to_sort, keys_literal.Slice(indices, limit_indices) @@ -1455,18 +1439,6 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { return Status::OK(); } -// Explicit instantiation of templatized Evaluate* methods. -// -template StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals); - -template StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, - absl::Span arg_literals); - -template StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals); - namespace { template std::unique_ptr> MatmulArray2DImpl( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 829fc2aba787a057914a038a3b22e19c763517b4..604b861913051574b038bd64a1b9d5ce2e79dbf3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -43,16 +44,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // specified. explicit HloEvaluator(int64 max_loop_iterations = -1); - // Evaluates an HLO module and an array of pointers to literals. - // Returns the evaluated result as a literal if successful. + // Evaluates an HLO module and an array of pointers to literals. Returns the + // evaluated result as a literal if successful. + // // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } + template StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals); + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -70,29 +79,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals); - - // Evaluates a single HLO instruction and an array of pointers to literals. - // Return the evaluated result as literal if successful. - // Precondition: - // 1. argument literals correspond to the input instruction's parameters in - // their post-ordering. - // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template - StatusOr Evaluate(HloInstruction* instruction, - absl::Span arg_literals); - - // Evaluates a single HLO instruction with constant operands. - // Returns the evaluated result as literal if successful. - // Precondition: - // 1. all operands of the input instruction are constants. - // 2. the instruction is not a Parameter operation. + absl::Span arg_literals); + template + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& l : arg_literals) { + arg_literal_ptrs.push_back(&l); + } + return Evaluate(computation, arg_literal_ptrs); + } + + // Gets the value of running a single HLO instruction. + // + // All of the operands to this instruction must be constants. StatusOr Evaluate(HloInstruction* instruction); // Same as Evaluate, except returning false on error and accepts an output @@ -120,6 +124,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); + void set_dynamic_dimension_inference( + DynamicDimensionInference* dynamic_dimension_inference) { + dynamic_dimension_inference_ = dynamic_dimension_inference; + } + // Enable the fast path for certain operations like dot or convolution. void set_use_fast_path(bool value) { use_fast_path_ = value; } @@ -158,6 +167,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -208,6 +219,29 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override; + // Unsupported HLOs, note some of them (such as BatchNorm*) are typically + // expanded in a semantic-preserving way into other HLOs by adding exanpsion + // HLO pass to the HLO optimization pass during compilation, which can then be + // handled by the evaluator. + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { + return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override { + return Unimplemented( + "BatchNormInference HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { + return Unimplemented( + "BatchNormTraining HLO is unsupported by the evaluator."); + }; + Status HandleInfeed(HloInstruction* infeed) override { + return Unimplemented("Infeed HLO is unsupported by the evaluator."); + }; + Status HandleOutfeed(HloInstruction* outfeed) override { + return Unimplemented("Outfeed HLO is unsupported by the evaluator."); + }; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. @@ -264,6 +298,15 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Max loop iterations to execute with no maximum if negative. int64 max_loop_iterations_; + // Module-level seed handle. + uint64 seed_; + // RNG engine. + std::minstd_rand0 engine_; + + // DynamicDimensionInference is used to evaluate GetDimensionSize, which + // returns the dynamic dimension size of its operand. + DynamicDimensionInference* dynamic_dimension_inference_; + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 7c56a0cc36f76069613dee608941ce83a8d1d654..674df0016afb5ed6379ded74987ae4a8eef2669a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -51,20 +51,18 @@ namespace { static std::array use_bf16_params{true, false}; -class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloTestBase { - protected: - HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { - evaluator_ = absl::make_unique(); - } +// Test fixture for the HloEvaluator. +// +// In bf16 mode, all f32 shapes are converted to bf16 before running. +class HloEvaluatorTest : public HloTestBase { + public: + HloEvaluatorTest() : use_bfloat16_(false) {} Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. - auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(m_.get()).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) + return evaluator_.Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } @@ -74,16 +72,12 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. - auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(module).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + return evaluator_.Evaluate(*module->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, float aabs = 0) { HloComputation::Builder b(TestName()); @@ -117,16 +111,27 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } - bool use_bfloat16_; + protected: + explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} + HloEvaluator evaluator_; + + const bool use_bfloat16_; std::unique_ptr m_ = CreateNewVerifiedModule(); }; -#define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ - TEST_P(test_case_name, test_name) +// Lets you write TEST_Ps that run twice, once with and once without bf16. +class HloEvaluatorBf16Test : public ::testing::WithParamInterface, + public HloEvaluatorTest { + protected: + HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {} +}; + +INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test, + ::testing::ValuesIn(use_bf16_params)); // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. -TEST_P(HloEvaluatorTest, DoesClamp) { +TEST_P(HloEvaluatorBf16Test, DoesClamp) { auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); @@ -147,7 +152,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { +TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { auto low = LiteralUtil::CreateR0(0.f); auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); @@ -170,7 +175,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST_P(HloEvaluatorTest, DoesSelect) { +TEST_P(HloEvaluatorBf16Test, DoesSelect) { auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -195,7 +200,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. -TEST_P(HloEvaluatorTest, DoesAdd) { +TEST_F(HloEvaluatorTest, DoesAdd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); @@ -204,7 +209,7 @@ TEST_P(HloEvaluatorTest, DoesAdd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. -TEST_P(HloEvaluatorTest, DoesAnd) { +TEST_P(HloEvaluatorBf16Test, DoesAnd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {4, 4}}); @@ -213,7 +218,7 @@ TEST_P(HloEvaluatorTest, DoesAnd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_P(HloEvaluatorTest, DoesOr) { +TEST_F(HloEvaluatorTest, DoesOr) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-100, 4}}); @@ -222,7 +227,7 @@ TEST_P(HloEvaluatorTest, DoesOr) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_P(HloEvaluatorTest, DoesXor) { +TEST_F(HloEvaluatorTest, DoesXor) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-104, 0}}); @@ -231,7 +236,7 @@ TEST_P(HloEvaluatorTest, DoesXor) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. -TEST_P(HloEvaluatorTest, DoesMultiply) { +TEST_F(HloEvaluatorTest, DoesMultiply) { auto lhs = LiteralUtil::CreateR2({{-1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); @@ -242,14 +247,14 @@ TEST_P(HloEvaluatorTest, DoesMultiply) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_P(HloEvaluatorTest, DoesDivideInt64) { +TEST_F(HloEvaluatorTest, DoesDivideInt64) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } -TEST_P(HloEvaluatorTest, DoesDivideDouble) { +TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) { auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = @@ -260,41 +265,41 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. -TEST_P(HloEvaluatorTest, DoesAbsR2) { +TEST_F(HloEvaluatorTest, DoesAbsR2) { auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR0) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR0) { auto operand = LiteralUtil::CreateR0(-1.0f); auto expected = LiteralUtil::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) { auto operand = LiteralUtil::CreateR1({}); auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesNegateR2) { +TEST_F(HloEvaluatorTest, DoesNegateR2) { auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); auto expected = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesCosR2) { +TEST_P(HloEvaluatorBf16Test, DoesCosR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesSinR2) { +TEST_P(HloEvaluatorBf16Test, DoesSinR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesNotR2) { +TEST_F(HloEvaluatorTest, DoesNotR2) { auto operand = LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, {-1, std::numeric_limits::max()}}); @@ -305,7 +310,7 @@ TEST_P(HloEvaluatorTest, DoesNotR2) { } // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { +TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); @@ -335,7 +340,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { } // Verifies Reshape operation is correctly evaluated. -TEST_P(HloEvaluatorTest, DoesReshape) { +TEST_F(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, @@ -361,7 +366,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { } // Verifies Broadcast operation is correctly evaluated. -TEST_P(HloEvaluatorTest, DoesBroadcast) { +TEST_F(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto output_literal = LiteralUtil::CreateR3( @@ -377,7 +382,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { +TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR0(111); auto output_literal = LiteralUtil::CreateR2( @@ -396,7 +401,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { +TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( @@ -418,7 +423,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { +TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( @@ -439,7 +444,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); @@ -458,7 +463,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2WithLayout( @@ -491,7 +496,7 @@ PaddingConfig CreatePaddingConfig( return padding_config; } -TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { +TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto operand = LiteralUtil::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = @@ -516,7 +521,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -551,7 +556,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, NegativePadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativePadding2D) { HloComputation::Builder b(TestName()); // input_array: @@ -593,7 +598,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } -TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) { HloComputation::Builder b(TestName()); // f32[4,3] { @@ -632,7 +637,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank1) { +TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) { HloComputation::Builder b(TestName()); // lhs: @@ -678,7 +683,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank1AndRank2) { +TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -716,7 +721,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank2) { +TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -766,7 +771,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank4AndRank4) { +TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) { HloComputation::Builder b(TestName()); auto lhs_array = absl::make_unique>(2, 2, 3, 1); @@ -810,7 +815,7 @@ TEST_P(HloEvaluatorTest, DotRank4AndRank4) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SimpleConv1D) { +TEST_P(HloEvaluatorBf16Test, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; @@ -859,7 +864,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { +TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -922,7 +927,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) { HloComputation::Builder b(TestName()); // clang-format off @@ -1003,7 +1008,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); // clang-format off @@ -1081,7 +1086,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1145,7 +1150,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1210,7 +1215,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, +TEST_P(HloEvaluatorBf16Test, DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { HloComputation::Builder b(TestName()); @@ -1283,7 +1288,7 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { +TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { HloComputation::Builder b(TestName()); std::vector input_dims = {1, 2, 2, 4}; std::vector filter_dims = {2, 2, 2, 8}; @@ -1419,7 +1424,7 @@ void BM_ReducePrecisely(int num_iters) { BENCHMARK(BM_ReducePrecisely); -TEST_P(HloEvaluatorTest, ReduceAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1461,7 +1466,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMax) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { HloComputation::Builder b(TestName()); // arg: @@ -1512,7 +1517,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) { HloComputation::Builder b(TestName()); // arg: @@ -1564,7 +1569,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1621,7 +1626,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) { HloComputation::Builder b(TestName()); // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. @@ -1684,7 +1689,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } -TEST_P(HloEvaluatorTest, StridedSlice) { +TEST_P(HloEvaluatorBf16Test, StridedSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1718,7 +1723,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1734,12 +1739,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {zero, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1754,7 +1761,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // Verifies that the HloEvaluator's implementation goes along with existing // backends' behavior, although this is not required by the spec. -TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1770,12 +1777,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2, 1}))); + auto two = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1788,7 +1797,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) { HloComputation::Builder b(TestName()); // arg: @@ -1804,15 +1813,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto update = b.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - shape, operand, update, start_indices)); + shape, operand, update, {zero, one})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1825,7 +1836,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1861,7 +1872,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1900,7 +1911,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Reverse) { +TEST_P(HloEvaluatorBf16Test, Reverse) { HloComputation::Builder b(TestName()); // Input shape is float[4x3x2x1]. @@ -1953,7 +1964,7 @@ TEST_P(HloEvaluatorTest, Reverse) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1977,7 +1988,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Check that EvaluateWithSubstitutions works if one of the operands to the op // we're evaluating is a constant. -TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -2000,7 +2011,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -2024,7 +2035,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { const char* hlo_text = R"( HloModule TensorFlowGatherV2 @@ -2048,7 +2059,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowGatherMultipleBatchDims @@ -2073,7 +2084,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2099,7 +2110,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2126,7 +2137,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) { const char* hlo_text = R"( HloModule DynamicSlice @@ -2149,7 +2160,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { const char* hlo_text = R"( HloModule BatchDynamicSlice @@ -2173,7 +2184,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -2195,7 +2206,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { const string hlo_text = R"( HloModule GatherXd @@ -2220,7 +2231,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV1 @@ -2251,7 +2262,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV2 @@ -2283,7 +2294,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2315,7 +2326,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2347,7 +2358,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { +TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2381,7 +2392,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2413,7 +2424,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowScatterMultipleBatchDims @@ -2446,7 +2457,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { const char* hlo_text = R"( HloModule TensorFlowScatterNd @@ -2482,7 +2493,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowScatterNdNonDefaultIndexVectorDim @@ -2519,7 +2530,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { const char* hlo_text = R"( HloModule DynamicUpdateSlice @@ -2551,7 +2562,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { const char* hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2583,7 +2594,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowScatter_ZeroDimBounds @@ -2612,7 +2623,7 @@ ENTRY main { operand, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { const string hlo_text = R"( HloModule Scatter_NoUpdateWindowDims @@ -2645,7 +2656,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter_NegativeIndices @@ -2680,7 +2691,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) { const string hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2716,7 +2727,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { const char* hlo_text = R"( HloModule TensorFlowScatterNd_OobUpdateWindow @@ -2755,7 +2766,7 @@ ENTRY main { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. -TEST_P(HloEvaluatorTest, DoesCompareBF16) { +TEST_F(HloEvaluatorTest, DoesCompareBF16) { // lhs >= rhs auto lhs = LiteralUtil::CreateR2( {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, @@ -2769,7 +2780,7 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) { std::move(rhs)); } -TEST_P(HloEvaluatorTest, Bf16Reduction) { +TEST_P(HloEvaluatorBf16Test, Bf16Reduction) { const string hlo_text = R"( HloModule Bf16Reduction @@ -2793,7 +2804,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); } -TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) { // Regression test for b/114735354. const string hlo_text = R"( HloModule SliceWithDifferentLayout @@ -2812,7 +2823,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } -TEST_P(HloEvaluatorTest, Bitcast) { +TEST_P(HloEvaluatorBf16Test, Bitcast) { // Regression test for b/114735354. constexpr absl::string_view hlo_text_base = R"( HloModule Bitcast @@ -2840,12 +2851,7 @@ ENTRY main { } // Check that s32 under/overflow doesn't trigger a ubsan failure. -TEST_P(HloEvaluatorTest, Int32Overflow) { - // Test not applicable to bf16; only applies to signed integral types. - if (use_bfloat16_) { - return; - } - +TEST_F(HloEvaluatorTest, Int32Overflow) { constexpr absl::string_view hlo_text = R"( HloModule Test @@ -2873,8 +2879,154 @@ ENTRY main { static_cast(pow31 * pow31)); } -INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, - ::testing::ValuesIn(use_bf16_params)); +TEST_F(HloEvaluatorTest, GetDimensionSize) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + size = u32[] parameter(0) + + data = s32[4] parameter(1) + + sum = s32[4] add(data, data) + + ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + + // Set up dynamic parameter binding. + TF_CHECK_OK(m_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(m_.get())); + + evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference); + Literal size_arg = LiteralUtil::CreateR0(3); + Literal data_arg = LiteralUtil::CreateR1({1, 2, 3, 4}); + + Literal actual = Evaluate({&size_arg, &data_arg}); + + EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); +} + +// Check that we get a useful error if we pass inputs of the wrong shape. +TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input_wrong_shape = LiteralUtil::CreateR1({0, 1}); + + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_, {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); +} + +// Check that we get a useful error if we pass too many or too few inputs. +TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input = LiteralUtil::CreateR1({0}); + + EXPECT_EQ( + HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(), + "Expected 1 argument, but got 2."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input, &input}) + .status() + .error_message(), + "Expected 1 argument, but got 2."); +} + +TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionInputLayout + + fused_computation { + param_0 = f32[20,20]{0,1} parameter(0) + ROOT bitcast = f32[20,20]{1,0} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{0,1} parameter(0) + ROOT fusion = f32[20,20]{1,0} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule MOFusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + bitcast = f32[20,20]{0,1} bitcast(param_0) + ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual_tuple = Evaluate({&args[0]}); + std::vector actual_literals = actual_tuple.DecomposeTuple(); + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual_literals[0].data())); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index f95a3c4ef9a88198722d54c8a9a4ef4017d23a2c..698b177310476f4a5bca11b423525eb20e7fcf98 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -917,7 +917,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); + if (std::isnan(low) || std::isnan(high)) { + return static_cast(NAN); + } + return static_cast( + std::fmin(high, std::fmax(value, low))); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -1406,10 +1410,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operand = dynamic_slice->operand(0); auto start_indices = dynamic_slice->operand(1); auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1418,33 +1424,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { primitive_util::IsIntegralType(start_indices->shape().element_type())); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; default: LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " @@ -1464,7 +1476,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( auto inferred_return_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1475,33 +1489,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; default: LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " @@ -1538,7 +1558,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Literal computed_result = - embedded_evaluator.Evaluate(*computation, arg_literals) + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. @@ -1635,8 +1655,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Extract a slice from the literal that corresponds to exactly the // row in dimension 'sort_dim'. std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); + absl::c_for_each(limit_indices, [](int64& index) { ++index; }); limit_indices[sort_dim] = sort_dim_elements; TF_ASSIGN_OR_RETURN(auto row_to_sort, keys_literal.Slice(indices, limit_indices) @@ -1799,7 +1818,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](Literal& literal) { return &literal; }); TF_ASSIGN_OR_RETURN(Literal computed_result, - embedded_evaluator.Evaluate( + embedded_evaluator.Evaluate( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on // the same computation. @@ -1915,8 +1934,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val_literal.Set({}, *selected_val); Literal computed_result = embedded_evaluator - .Evaluate( - *select, {&selected_val_literal, &curr_val_literal}) + .Evaluate(*select, + {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); bool selected = !computed_result.Get({}); if (selected) { @@ -1937,9 +1956,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scattered_literal.Set({}, scattered); Literal computed_result = embedded_evaluator - .Evaluate( - *scatter, - {&source_literal_scatter, &scattered_literal}) + .Evaluate(*scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again @@ -2013,8 +2031,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(result_val); Literal computed_result = embedded_evaluator - .Evaluate( - *function, {&result_val_literal, &curr_val_literal}) + .Evaluate(*function, + {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again @@ -2376,9 +2394,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(updates.Get(update_index)); Literal updated_result = embedded_evaluator - .Evaluate( - *scatter->to_apply(), - {&result_value_literal, &update_value_literal}) + .Evaluate(*scatter->to_apply(), + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. @@ -2617,7 +2634,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); + return InvalidArgument("Double is not supported for reduce precision"); } template < @@ -2632,12 +2649,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleReducePrecision(reduce_precision); } - template ::value || - std::is_same::value || - std::is_integral::value || - std::is_floating_point::value>::type* = nullptr> + template < + typename NativeT, + typename std::enable_if< + std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); @@ -2668,12 +2686,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template ::value || - std::is_same::value || - std::is_integral::value || - std::is_floating_point::value)>::type* = nullptr> + template < + typename NativeT, + typename std::enable_if< + !(std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { return UnsupportedTypeError(iota); } @@ -2681,6 +2700,103 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleIota(iota); } + template ::value || + std::is_floating_point::value)>::type* = nullptr> + Status HandleRng(HloInstruction* random) { + return UnsupportedTypeError(random); + } + template ::value)>::type* = nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + std::uniform_real_distribution generator( + low.Get({}), high.Get({})); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return generator(parent_->engine_); + })); + break; + } + case RNG_NORMAL: { + const Literal& mean = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& stddev = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + std::normal_distribution generator(mean.Get({}), + stddev.Get({})); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return generator(parent_->engine_); + })); + break; + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + template ::value)>::type* = + nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + // Note std::uniform_int_distribution assumes interval is closed, i.e., + // [low, high], but we want [low, high) instead. Hence high-1 is used as + // the upper range. + std::uniform_int_distribution generator( + low.Get({}), high.Get({}) - 1); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return static_cast(generator(parent_->engine_)); + })); + break; + } + case RNG_NORMAL: { + return Unimplemented( + "Normal distribution is not supported for integral types."); + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + Status HandleRng(HloInstruction* random) override { + return HandleRng(random); + } + private: // Creates a vector of multipliers which can be used to create a linear index // into shape. @@ -2740,12 +2856,27 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicSlice(const Literal& operand_literal, - const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + StatusOr DynamicSlice( + const Literal& operand_literal, + absl::Span start_indices, + const Shape& result_shape) { + std::vector start; + // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish + // between the cases, this currently assumes there is at least 1 index. That + // is wrong in the general case, because for scalar indices, if the operand + // is scalar, then there are no indices. This problem with resolve itself. + const HloInstruction* first_index = start_indices[0]; + if (first_index->shape().rank() == 1) { + auto start_indices_typed = + parent_->GetEvaluatedLiteralFor(first_index).data(); + start = std::vector(start_indices_typed.begin(), + start_indices_typed.end()); + } else { + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } + } // Clamp the start indices so the slice is in-bounds w.r.t the operand. for (int64 i = 0; i < start.size(); ++i) { @@ -2771,14 +2902,28 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicUpdateSlice(const Literal& operand_literal, - const Literal& update_literal, - const Literal& start_indices_literal) { + StatusOr DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + absl::Span start_indices) { auto result = operand_literal.Clone(); - auto start_indices_typed = start_indices_literal.data(); const auto rank = result.shape().rank(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + std::vector start; + // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish + // between the cases, this currently assumes there is at least 1 index. That + // is wrong in the general case, because for scalar indices, if the operand + // is scalar, then there are no indices. This problem with resolve itself. + const HloInstruction* first_index = start_indices[0]; + if (first_index->shape().rank() == 1) { + auto start_indices_typed = + parent_->GetEvaluatedLiteralFor(first_index).data(); + start = std::vector(start_indices_typed.begin(), + start_indices_typed.end()); + } else { + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } + } // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. for (int64 i = 0; i < rank; ++i) { diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index c919dbd82d3668c477bf37074f1d56f8cb7d9506..862b2029718bbd802b69d789b66683a4edfa2367 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -25,7 +26,9 @@ namespace xla { namespace { -StatusOr ReplaceGetSize(HloInstruction* instr) { +StatusOr ReplaceGetSize( + HloInstruction* instr, + const DynamicDimensionInference* dynamic_dimension_inference) { if (instr->opcode() != HloOpcode::kGetDimensionSize) { return false; } @@ -36,10 +39,18 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { instr->operand(0)->shape(), instr->dimension())); TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); - uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); - HloInstruction* new_instr = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + HloInstruction* operand = instr->mutable_operand(0); + int64 dim = instr->dimension(); + HloInstruction* dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + } else { + uint32 size = instr->operand(0)->shape().dimensions(dim); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + } return true; } @@ -48,10 +59,13 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { bool changed = false; HloProto proto; + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module)); *proto.mutable_hlo_module() = module->ToProto(); for (auto* computation : module->computations()) { for (auto instruction : computation->instructions()) { - TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + TF_ASSIGN_OR_RETURN(bool replaced, + ReplaceGetSize(instruction, &inference)); changed = changed || replaced; } } diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h index 30f44c23a835b3bcc935caaa917e040e07c4e703..9aa79fe66b665c48ec871c4188e44ba2056de3ad 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -21,7 +21,9 @@ limitations under the License. namespace xla { -// Pass to replace a kGetDimensionSize instruction with a constant instruction. +// Pass to replace a kGetDimensionSize instruction with a hlo instruction +// representing the dynamic size if the dimension is dynamic, otherwise a +// constant instruction representing the static size. class HloGetDimensionSizeRewriter : public HloModulePass { public: absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index c6eaead8dd9c3dbf19bce34989e8b1f0f60468cc..4c7f5e9e7dfb12a8cb699bdf397eab21983342a1 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -24,9 +24,9 @@ limitations under the License. #include #include #include -#include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -380,7 +380,7 @@ class HloDotDumper { // Each HloInstruction dumped gets a monotically-increasing node ID. This // must start at 1, because that's where graphviz's accounting starts. int64 next_node_id_ = 1; - std::unordered_map node_ids_; + absl::flat_hash_map node_ids_; // The "root" tag doesn't have an associated HloInstruction pointer, so we // need to store it outside the map. @@ -397,7 +397,7 @@ class HloDotDumper { // Each HloComputation that's emitted gets a monotonically-increasing ID. int64 next_cluster_id_ = 1; - std::unordered_map cluster_ids_; + absl::flat_hash_map cluster_ids_; // Edges to print from Footer(). Edges come at the end because graphviz is // unhappy if an edge from a subcomputation to a node in the outer computation @@ -407,7 +407,7 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map + absl::flat_hash_map sharding_colors_; int64 next_shard_color_ = 0; }; @@ -561,8 +561,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { } // Show the subcomputation if we're showing any of its members. - return std::any_of( - subcomp->instructions().begin(), subcomp->instructions().end(), + return absl::c_any_of( + subcomp->instructions(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -735,15 +735,14 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { const int kMinUsersToOmit = 3; return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() && !instr->IsFused() && - std::count_if(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return filter_.Show(user); - }) > kMinUsersToOmit && - std::all_of(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return !filter_.Show(user) || - user->opcode() == HloOpcode::kGetTupleElement; - }); + absl::c_count_if(instr->users(), + [&](const HloInstruction* user) { + return filter_.Show(user); + }) > kMinUsersToOmit && + absl::c_all_of(instr->users(), [&](const HloInstruction* user) { + return !filter_.Show(user) || + user->opcode() == HloOpcode::kGetTupleElement; + }); } string HloDotDumper::DumpInstruction(const HloInstruction* instr) { @@ -900,12 +899,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { // the same color as a parameter. Unless the merged-in parameter is a // parameter to a fusion node that is bound to a constant -- these aren't // "real" parameters from the user's perspective. - if (std::any_of(instr->operands().begin(), instr->operands().end(), - [&](const HloInstruction* operand) { - return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand) && - TryGetFusionParameterConstant(operand) == nullptr; - })) { + if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kParameter && + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; + })) { return parameter_color; } @@ -1286,7 +1284,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. - std::unordered_map nodes; + absl::flat_hash_map nodes; std::deque> worklist; worklist.push_back({root, 0}); while (!worklist.empty()) { @@ -1307,7 +1305,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // are not interesting to the graph at hand. if (instr == root || instr->opcode() != HloOpcode::kTuple) { for (const HloInstruction* operand : instr->operands()) { - if (!nodes.count(operand)) { + if (!nodes.contains(operand)) { worklist.push_back({operand, depth + 1}); } } @@ -1335,7 +1333,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, continue; } for (const HloInstruction* user : instr->users()) { - if (!nodes.count(user)) { + if (!nodes.contains(user)) { worklist.push_back({user, depth + 1}); } } @@ -1344,7 +1342,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. // Nodes in subcomputations are always shown. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || + return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant || instr->parent() != root->parent(); }; @@ -1355,12 +1353,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, NodeFilterResult& filter_result = kv.second; const auto& operands = instr->operands(); - if (std::any_of(operands.begin(), operands.end(), is_displayed) && - !std::all_of(operands.begin(), operands.end(), is_displayed)) { + if (absl::c_any_of(operands, is_displayed) && + !absl::c_all_of(operands, is_displayed)) { // Mark nodes with some operands omitted appropriately. filter_result = kSomeOperandsOmitted; - } else if (!operands.empty() && - std::none_of(operands.begin(), operands.end(), is_displayed)) { + } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) { // Mark nodes with *all* operands omitted appropriately. filter_result = kOmitNodeOperands; } @@ -1368,8 +1365,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their // users made it into the graph. if (filter_result == kSomeUsersOmitted && - std::all_of(instr->users().begin(), instr->users().end(), - is_displayed)) { + absl::c_all_of(instr->users(), is_displayed)) { filter_result = kNormalNode; } } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index 6e1597fd03db0a78aa560340b7b9b64fe500df0c..b01c00121b3363630b83a1e49d0027a66f3a9e1a 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -17,22 +17,34 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { + +bool HloInputOutputAliasConfig::OutputHasAlias( + const ShapeIndex& output_index) const { + return alias_.element(output_index).has_value(); +} + Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + const ShapeIndex& param_index, + AliasKind kind) { + TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias) + << kind; TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) << absl::StrCat("Tring to set up alias at ", output_index.ToString(), " which is an invalid index for shape ", ShapeUtil::HumanString(alias_.shape())); + TF_RET_CHECK(param_number >= 0) << param_number; + TF_RET_CHECK(!OutputHasAlias(output_index)) + << "Output index " << output_index << " already has an alias setup"; // Output can't be aliased with multiple parameters. TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat( "Trying to set up output alias for param %lld at %s but failed: output " "index %s is already aliased with param %lld at %s", param_number, param_index.ToString(), output_index.ToString(), - alias_.element(output_index)->first, - alias_.element(output_index)->second.ToString()); + alias_.element(output_index)->parameter_number, + alias_.element(output_index)->parameter_index.ToString()); (*alias_.mutable_element(output_index)) = - std::make_pair(param_number, param_index); + Alias(kind, param_number, param_index); VLOG(4) << "Set up alias between output index " << output_index.ToString() << " and parameter " << param_index << " at index " << param_index.ToString(); @@ -42,15 +54,24 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { HloInputOutputAliasProto result; alias_.ForEachElement( - [&](const ShapeIndex& index, - const absl::optional>& data) { + [&](const ShapeIndex& index, const absl::optional& data) { if (data) { HloInputOutputAliasProto::AliasEntryProto entry; + switch (data->kind) { + case AliasKind::kUserAlias: + entry.set_kind(HloInputOutputAliasProto::USER_ALIAS); + break; + case AliasKind::kSystemAlias: + entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS); + break; + default: + LOG(FATAL) << "Unknown alias kind " << data->kind; + } for (int64 i : index) { entry.add_output_shape_index(i); } - entry.set_parameter_number(data->first); - for (int64 i : data->second) { + entry.set_parameter_number(data->parameter_number); + for (int64 i : data->parameter_index) { entry.add_parameter_shape_index(i); } result.add_entries()->Swap(&entry); @@ -66,14 +87,18 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( proto.entries()) { ShapeIndex output_index(entry.output_shape_index().begin(), entry.output_shape_index().end()); - int64 param_number = entry.parameter_number(); ShapeIndex param_index(entry.parameter_shape_index().begin(), entry.parameter_shape_index().end()); + // Handle backward compatibility with existing protos, which only knew of + // system aliases. + AliasKind kind = AliasKind::kSystemAlias; + if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) { + kind = AliasKind::kUserAlias; + } TF_RETURN_IF_ERROR( - result.SetUpAlias(output_index, param_number, param_index)); + result.SetUpAlias(output_index, param_number, param_index, kind)); } - return result; } @@ -81,45 +106,44 @@ string HloInputOutputAliasConfig::ToString() const { std::vector pieces; pieces.push_back("HloInputOutputAliasConfig"); - ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { + const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM"; pieces.push_back(absl::StrFormat( - " OutputIndex %s is aliased with parameter %lld at %s:", - output_index.ToString(), param_number, param_index.ToString())); + " OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:", + output_index.ToString(), kind, alias.parameter_number, + alias.parameter_index.ToString())); }); - return absl::StrJoin(pieces, "\n"); } -bool HloInputOutputAliasConfig::ParameterHasAlias( +HloInputOutputAliasConfig::AliasKind +HloInputOutputAliasConfig::ParameterAliasKind( int64 param_number, const ShapeIndex& param_index) const { - bool output = false; + AliasKind kind = AliasKind::kNoAlias; alias_.ForEachElement( - [&](const xla::ShapeIndex&, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { - output = true; + [&](const xla::ShapeIndex&, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { + kind = alias->kind; } }); - return output; + return kind; } absl::optional HloInputOutputAliasConfig::GetAliasedOutput( int64 param_number, const ShapeIndex& param_index) const { absl::optional output; alias_.ForEachElement( - [&](const xla::ShapeIndex& output_index, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { + [&](const xla::ShapeIndex& output_index, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { output = output_index; } }); return output; } -absl::optional> +absl::optional HloInputOutputAliasConfig::GetAliasedParameter( const ShapeIndex& output_index) const { CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); @@ -128,10 +152,9 @@ HloInputOutputAliasConfig::GetAliasedParameter( void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { alias_.ForEachElement( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - fn(output_index, aliased->first, aliased->second); + fn(output_index, *aliased); } }); } @@ -139,10 +162,9 @@ void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { Status HloInputOutputAliasConfig::ForEachAliasWithStatus( AliasFnWithStatus fn) const { return alias_.ForEachElementWithStatus( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + TF_RETURN_IF_ERROR(fn(output_index, *aliased)); } return Status::OK(); }); @@ -158,20 +180,19 @@ Status HloInputOutputAliasConfig::Verify( param_has_seen.emplace_back(param->shape()); } return ForEachAliasWithStatus([&](const ShapeIndex& output_index, - int64 param_number, - const ShapeIndex& param_index) -> Status { + const Alias& alias) -> Status { const HloInstruction* root = entry->root_instruction(); - TF_RET_CHECK(0 <= param_number); - TF_RET_CHECK(entry->num_parameters() > param_number); + TF_RET_CHECK(0 <= alias.parameter_number); + TF_RET_CHECK(entry->num_parameters() > alias.parameter_number); const Shape& param_shape = - entry->parameter_instruction(param_number)->shape(); + entry->parameter_instruction(alias.parameter_number)->shape(); const Shape& output_shape = root->shape(); - TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index)); TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); const Shape& param_subshape = - ShapeUtil::GetSubshape(param_shape, param_index); + ShapeUtil::GetSubshape(param_shape, alias.parameter_index); const Shape& output_subshape = ShapeUtil::GetSubshape(output_shape, output_index); TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape)); @@ -182,19 +203,20 @@ Status HloInputOutputAliasConfig::Verify( "Expected aliased input %lld at index %s and output at index %s to " "have the same size. Input sub-shape is %s with size %lld, output " "sub-shape is %s with size %lld", - param_number, param_index.ToString(), output_index.ToString(), + alias.parameter_number, alias.parameter_index.ToString(), + output_index.ToString(), ShapeUtil::HumanStringWithLayout(param_subshape), size_func(param_subshape), ShapeUtil::HumanStringWithLayout(output_subshape), size_func(output_subshape)); } - // Check each param_number and param_index pair only show up once. No - // input can be aliased with output buffers. - TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); - - *(param_has_seen[param_number].mutable_element(param_index)) = true; - + // Check each alias.parameter_number and alias.parameter_index pair only + // show up once. No input can be aliased with output buffers. + TF_RET_CHECK(param_has_seen[alias.parameter_number].element( + alias.parameter_index) == false); + *(param_has_seen[alias.parameter_number].mutable_element( + alias.parameter_index)) = true; return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index 439676b1546c4af7f781fb80bccffd5248309b0f..b0b71dece81b561f492767db8c1ccbe3fde442d4 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -31,6 +32,28 @@ class HloModule; // parameter index in the entry computation. class HloInputOutputAliasConfig { public: + // The kind of aliases which can be set. A kUserAlias is one setup at + // compilation time by the user, and has to be respected. A kSystemAlias one + // might be setup by the compiler, if it decides it is convenient to do so. + enum AliasKind { + kNoAlias, + kUserAlias, + kSystemAlias, + }; + + // Defines the alias information for a given output buffer. A given output + // buffer shape index can refer only to one parameter+index. + struct Alias { + Alias(AliasKind kind, int64 parameter_number, ShapeIndex parameter_index) + : kind(kind), + parameter_number(parameter_number), + parameter_index(std::move(parameter_index)) {} + + AliasKind kind; + int64 parameter_number; + ShapeIndex parameter_index; + }; + HloInputOutputAliasConfig() = default; explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} @@ -40,12 +63,22 @@ class HloInputOutputAliasConfig { // Sets up alias config from `output_index` to `param_index` at // `param_number`. Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index); + const ShapeIndex& param_index, AliasKind kind); + + // Returns the kind of alias for the given parameter number and parameter + // index. If no alias exists, AliasKind::kNoAlias is returned. + AliasKind ParameterAliasKind(int64 param_number, + const ShapeIndex& param_index) const; // Returns true if the given parameter is aliased with one of the output // buffers. bool ParameterHasAlias(int64 param_number, - const ShapeIndex& param_index) const; + const ShapeIndex& param_index) const { + return ParameterAliasKind(param_number, param_index) != AliasKind::kNoAlias; + } + + // Checks whether the provided output index has already been aliased. + bool OutputHasAlias(const ShapeIndex& output_index) const; // (De)Serializes an HloInputOutoutAliasConfig to/from an // HloInputOutoutAliasProto. @@ -63,19 +96,17 @@ class HloInputOutputAliasConfig { // Returns the number of parameter and index of the parameter buffer that the // given output buffer index is aliased with. A nullopt is returned if there // is no parameter is aliased with the specific output. - absl::optional> GetAliasedParameter( + absl::optional GetAliasedParameter( const ShapeIndex& output_index) const; using AliasFn = - std::function; + std::function; // Iterates through each aliased output and input. void ForEachAlias(AliasFn fn) const; using AliasFnWithStatus = - std::function; + std::function; // Verifies that the given config is valid for the given module. // Specifically, the config's input and output should be in-bound and size of @@ -90,9 +121,10 @@ class HloInputOutputAliasConfig { private: // A ShapeTree which indicates the list of buffers that's expected to be // aliased. The key on this shape tree represents the output index. The value - // is a pair of parameter number and index into the buffer. If the value is - // nullopt, it means there is no parameter aliasing for this output. - ShapeTree>> alias_; + // is an Alias data structure which defines the input parameter coordinates. + // If the value is nullopt, it means there is no parameter aliasing for this + // output. + ShapeTree> alias_; }; std::ostream& operator<<(std::ostream& out, diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index aeb9b0fdc8b6cca87731a2d4aae25120af6c3215..a46a107723de30176241aae01b268a8c10d991d3 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -45,11 +45,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_TRUE(aliased_output); EXPECT_EQ(aliased_output.value(), output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); EXPECT_TRUE(aliased_param); - EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + EXPECT_EQ(aliased_param->parameter_number, param_number); + EXPECT_EQ(aliased_param->parameter_index, param_index); } void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, @@ -60,11 +61,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_FALSE(aliased_output && aliased_output == output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); - EXPECT_FALSE(aliased_param && aliased_param->first == param_number && - aliased_param->second == param_index); + EXPECT_FALSE(aliased_param && + aliased_param->parameter_number == param_number && + aliased_param->parameter_index == param_index); } }; @@ -84,8 +86,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/1, /*param_index=*/{}, config); @@ -114,11 +118,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{0})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{1})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, config); @@ -149,11 +157,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -176,8 +188,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -200,11 +214,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + ASSERT_IS_NOT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 29155359872d061544159ea3374554a35939ff5e..029d170317cbd119a62989ad6ae115abde1977e9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -82,15 +83,14 @@ StatusOr> HloInstruction::CreateFromProto( return computation_map.at(proto.called_computation_ids(index)); }; - TF_RET_CHECK(std::all_of( - proto.operand_ids().begin(), proto.operand_ids().end(), - [&instruction_map](int64 id) { return instruction_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.operand_ids(), + [&](int64 id) { return instruction_map.contains(id); })) << proto.name() << " instruction contains invalid operand id(s)"; - TF_RET_CHECK(std::all_of( - proto.called_computation_ids().begin(), - proto.called_computation_ids().end(), - [&computation_map](int64 id) { return computation_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.called_computation_ids(), + [&](int64 id) { return computation_map.contains(id); })) << proto.name() << " instruction references invalid computation id(s)"; Shape shape(proto.shape()); @@ -452,13 +452,50 @@ StatusOr> HloInstruction::CreateFromProto( CreatePad(shape, operands(0), operands(1), proto.padding_config()); break; case HloOpcode::kDynamicSlice: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "DynamicSlice instruction should have 2 operands but sees " - << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); - instruction = - CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "DynamicSlice instruction should have at least 1 operands but " + "sees " + << proto.operand_ids_size(); + if (proto.operand_ids_size() == 2 && operands(1)->shape().rank() == 1) { + // TODO(b/118437727): Old form, remove this path. + instruction = + CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); + } else { + // New form + auto expected_operands = 1 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicSlice instruction should have " << expected_operands + << " operands, but has " << proto.operand_ids_size(); + const auto& operand_vector = all_operands(); + instruction = CreateDynamicSlice( + shape, operands(0), absl::MakeSpan(operand_vector).subspan(1), + slice_sizes); + } + break; + } + case HloOpcode::kDynamicUpdateSlice: { + TF_RET_CHECK(proto.operand_ids_size() >= 2) + << "DynamicUpdateSlice instruction should have at least 2 operands " + "but sees " + << proto.operand_ids_size(); + if (proto.operand_ids_size() == 3 && operands(2)->shape().rank() == 1) { + // TODO(b/118437727): Old form, remove this path. + instruction = CreateDynamicUpdateSlice(shape, operands(0), operands(1), + operands(2)); + } else { + // New form + auto expected_operands = 2 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicUpdateSlice instruction should have " + << expected_operands << " operands, but has " + << proto.operand_ids_size(); + const auto& operand_vector = all_operands(); + instruction = + CreateDynamicUpdateSlice(shape, operands(0), operands(1), + absl::MakeSpan(operand_vector).subspan(2)); + } break; } case HloOpcode::kGather: { @@ -917,6 +954,14 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, shape, operand, start_indices, slice_sizes); } +/* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes) { + return absl::make_unique( + shape, operand, start_indices, slice_sizes); +} + /* static */ std::unique_ptr HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, @@ -926,6 +971,14 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, shape, operand, update, start_indices); } +/* static */ std::unique_ptr +HloInstruction::CreateDynamicUpdateSlice( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices) { + return absl::make_unique( + shape, operand, update, start_indices); +} + /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, absl::Span operands, int64 dimension) { @@ -1382,9 +1435,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateReshape(shape, new_operands[0]); break; case HloOpcode::kDynamicUpdateSlice: - CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], - new_operands[2]); + new_operands.subspan(2)); break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); @@ -1546,12 +1598,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const { Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); - if (std::find(control_successors_.begin(), control_successors_.end(), - instruction) == control_successors_.end()) { + if (!absl::c_linear_search(control_successors_, instruction)) { control_successors_.push_back(instruction); - TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), - instruction->control_predecessors_.end(), - this) == instruction->control_predecessors_.end()); + TF_RET_CHECK( + !absl::c_linear_search(instruction->control_predecessors_, this)); instruction->control_predecessors_.push_back(this); } return Status::OK(); @@ -1800,7 +1850,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) { user_set_.erase(set_it); // This is linear in the number of the users, but a vector provides a stable // iteration order and much faster traversal. - auto vec_it = std::find(users_.begin(), users_.end(), user); + auto vec_it = absl::c_find(users_, user); CHECK(vec_it != users_.end()); users_.erase(vec_it); } @@ -1818,8 +1868,7 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, RemoveUser(user); - TF_RET_CHECK( - std::count(user->operands_.begin(), user->operands_.end(), this) >= 0); + TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0); std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); @@ -1832,6 +1881,16 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, Status HloInstruction::ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand) { + auto old_operand = operand(operand_num); + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); + return ReplaceOperandWithDifferentShape(operand_num, new_operand); +} + +Status HloInstruction::ReplaceOperandWithDifferentShape( + int64 operand_num, HloInstruction* new_operand) { TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); @@ -1839,17 +1898,12 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, return Status::OK(); } - TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), - new_operand->shape())) - << old_operand->shape() << " is not compatible with " - << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " << new_operand->name() << ", was " << old_operand->name(); - if (std::find(operands_.begin(), operands_.end(), old_operand) == - operands_.end()) { + if (!absl::c_linear_search(operands_, old_operand)) { old_operand->RemoveUser(this); } new_operand->AddUser(this); @@ -1857,6 +1911,14 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, } Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) + << shape() << " is not compatible with " << new_producer->shape(); + return ReplaceAllUsesWithDifferentShape(new_producer); +} + +Status HloInstruction::ReplaceAllUsesWithDifferentShape( + HloInstruction* new_producer) { bool new_producer_is_user = false; for (HloInstruction* user : users()) { if (user == new_producer) { @@ -1881,7 +1943,8 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { AddUser(new_producer); } if (parent_ && parent_->root_instruction() == this) { - parent_->set_root_instruction(new_producer); + parent_->set_root_instruction(new_producer, + /*accept_different_shape=*/true); } return Status::OK(); @@ -2824,7 +2887,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return UseKind::kReuse; case HloOpcode::kDynamicUpdateSlice: - // Dynamic-update-slice reuses only operand 2 (start_indices). + // Dynamic-update-slice reuses only start_indices. if (i == 0 || i == 1) { return UseKind::kUse; } @@ -2877,10 +2940,10 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = - std::any_of(padding.dimensions().begin(), padding.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.interior_padding() != 0; - }); + absl::c_any_of(padding.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.interior_padding() != 0; + }); return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 36e1ab49319a3e28143ef4d08888c68c86fbcf62..789ace6a263b37e08858b534028c26554105eb30 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -556,15 +556,26 @@ class HloInstruction { // Creates a slice instruction, where the first operand is sliced by // start indices specified in the second operand, and by size specified in // 'slice_sizes'. + ABSL_DEPRECATED("Use span-of-indices form instead") static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes); + // Same as above, but expects a span of scalar start indices. + static std::unique_ptr CreateDynamicSlice( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. + ABSL_DEPRECATED("Use span-of-indices form instead") static std::unique_ptr CreateDynamicUpdateSlice( const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices); + // Same as above, but expects a span of scalar start indices. + static std::unique_ptr CreateDynamicUpdateSlice( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices); // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. @@ -928,11 +939,16 @@ class HloInstruction { // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); - // Replaces the specified operand with new_operand. + // Replaces the specified operand with new_operand. The old and new operands + // must have compatible shapes ignoring floating-point precision. // // This function does NOT remove duplicated operands even if this instruction // is a fusion, so that the existing operand numbers do not change. - Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); + Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand); + + // Same as ReplaceOperandWith(), but new_operand can have a different shape. + Status ReplaceOperandWithDifferentShape(int64 operand_num, + HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use @@ -941,10 +957,16 @@ class HloInstruction { // If this instruction is the root of its computation, sets the computation's // root to new_producer. // + // The new producer must have a compatible shape ignoring floating-point + // precision. + // // If a user is a fusion instruction, this function will remove any duplicated // operands of it which could be created due to this replacement. Status ReplaceAllUsesWith(HloInstruction* new_producer); + // Same as ReplaceAllUsesWith, but new_producer can have a different shape. + Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); + // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when // complete. If ignore_control_predecessors is true, instructions only diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8048e332cb57747286758b75773b29ba154aa888..35f031f29a7aca8db7ebe2fbcfdcebb7a778d703 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -55,13 +56,13 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } Status HandleParameter(HloInstruction* parameter) override { - EXPECT_EQ(0, count_.count(parameter)); + EXPECT_FALSE(count_.contains(parameter)); count_[parameter] = GetCountsForNode(parameter); return Status::OK(); } Status HandleConstant(HloInstruction* constant) override { - EXPECT_EQ(0, count_.count(constant)); + EXPECT_FALSE(count_.contains(constant)); count_[constant] = GetCountsForNode(constant); return Status::OK(); } @@ -69,25 +70,25 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override { auto lhs = add->operand(0); auto rhs = add->operand(1); - EXPECT_EQ(0, count_.count(add)); - EXPECT_GT(count_.count(lhs), 0); - EXPECT_GT(count_.count(rhs), 0); + EXPECT_FALSE(count_.contains(add)); + EXPECT_TRUE(count_.contains(lhs)); + EXPECT_TRUE(count_.contains(rhs)); count_[add] = GetCountsForNode(add); return Status::OK(); } Status HandleNegate(HloInstruction* negate) override { auto operand = negate->operand(0); - EXPECT_EQ(0, count_.count(negate)); - EXPECT_GT(count_.count(operand), 0); + EXPECT_FALSE(count_.contains(negate)); + EXPECT_TRUE(count_.contains(operand)); count_[negate] = GetCountsForNode(negate); return Status::OK(); } Status HandleMap(HloInstruction* map) override { - EXPECT_EQ(0, count_.count(map)); + EXPECT_FALSE(count_.contains(map)); for (HloInstruction* arg : map->operands()) { - EXPECT_GT(count_.count(arg), 0); + EXPECT_TRUE(count_.contains(arg)); } count_[map] = GetCountsForNode(map); return Status::OK(); @@ -96,9 +97,9 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - EXPECT_EQ(0, count_.count(reduce)); - EXPECT_GT(count_.count(arg), 0); - EXPECT_GT(count_.count(init_value), 0); + EXPECT_FALSE(count_.contains(reduce)); + EXPECT_TRUE(count_.contains(arg)); + EXPECT_TRUE(count_.contains(init_value)); count_[reduce] = GetCountsForNode(reduce); return Status::OK(); } @@ -128,7 +129,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } // Counters for HLOs. Maps HLO to a NumOpsAndUsers. - std::unordered_map count_; + absl::flat_hash_map count_; }; TEST_F(HloInstructionTest, BasicProperties) { @@ -137,7 +138,7 @@ TEST_F(HloInstructionTest, BasicProperties) { EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); - EXPECT_EQ(0, parameter->operand_count()); + EXPECT_FALSE(parameter->operand_count()); } TEST_F(HloInstructionTest, UserWithTwoOperands) { @@ -981,9 +982,9 @@ TEST_F(HloInstructionTest, FunctionVisitor) { module->AddEntryComputation(builder.Build()); int visit_num = 0; - std::unordered_map visit_order; + absl::flat_hash_map visit_order; EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) { - EXPECT_EQ(0, visit_order.count(inst)); + EXPECT_FALSE(visit_order.contains(inst)); visit_order[inst] = visit_num; visit_num++; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1f37b284a2606f5256a02510a93bcec8664b0eb6..c24bf41ff8e0853d97f512be98603133c0d4f63d 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -42,11 +42,9 @@ using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { std::vector operand_indices = instruction->OperandIndices(operand); - return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); + return absl::c_all_of(operand_indices, [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); } string PrecisionConfigToString(const PrecisionConfig& precision_config) { @@ -814,8 +812,7 @@ std::vector HloSliceInstruction::ExtraAttributesToStringImpl( std::vector bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); + absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; }); for (int i = 0; i < slice_starts_.size(); ++i) { string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); bounds.push_back( @@ -1051,8 +1048,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( void HloFusionInstruction::MergeFusionInstruction( HloFusionInstruction* instruction_to_merge) { - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); + CHECK(absl::c_linear_search(operands(), instruction_to_merge)); // Clone the instruction from which to merge fused instructions. std::unique_ptr cloned = instruction_to_merge->Clone(); HloFusionInstruction* cloned_fusion = @@ -1219,8 +1215,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( // corresponding fused parameter instruction. Renumber parameters as // necessary to make parameter numbers consistent with their index in the // fused_parameter_ vector. - bool in_operand_list = std::find(operands().begin(), operands().end(), - instruction_to_fuse) != operands().end(); + bool in_operand_list = + absl::c_linear_search(operands(), instruction_to_fuse); CHECK(add_output || in_operand_list); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { // We assume all uses of a kTuple operation are GTE ops, not another @@ -1324,7 +1320,7 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( if (newly_created_tuple_instr) { HloInstruction* new_instr = parent()->AddInstruction( HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr)); } int64 index = tuple_elements.size(); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { @@ -2007,6 +2003,18 @@ HloDynamicSliceInstruction::HloDynamicSliceInstruction( AppendOperand(start_indices); } +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes) + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) @@ -2016,6 +2024,17 @@ HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( AppendOperand(start_indices); } +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { @@ -2041,9 +2060,14 @@ std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); - return absl::make_unique( - shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) { + // TODO(b/118437727): Old form, remove this path. + return absl::make_unique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + } else { + return absl::make_unique( + shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_); + } } HloGatherInstruction::HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index ca212c7f2c98f75ceefc14b7fbc2a1f530c06cf7..e6111cfb57581589070b8e34556bdfe8239b4fd3 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1183,7 +1183,22 @@ class HloDynamicIndexInstruction : public HloInstruction { public: explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) : HloInstruction(opcode, shape) {} - virtual int64 index_operand_number() const = 0; + virtual int64 first_index_operand_number() const = 0; + + // Returns a subspan of operands which represent the start indices. + absl::Span index_operands() const { + return absl::MakeSpan(operands()).subspan(first_index_operand_number()); + } + + // Returns the shapes of the index operands. + std::vector index_shapes() const { + std::vector shapes; + auto indices = index_operands(); + for (const HloInstruction* index : indices) { + shapes.push_back(index->shape()); + } + return shapes; + } }; class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { @@ -1192,6 +1207,10 @@ class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes); + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1204,7 +1223,7 @@ class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; - int64 index_operand_number() const override { return 1; } + int64 first_index_operand_number() const override { return 1; } private: std::vector ExtraAttributesToStringImpl( @@ -1229,8 +1248,11 @@ class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices); + explicit HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices); - int64 index_operand_number() const override { return 2; } + int64 first_index_operand_number() const override { return 2; } }; class HloGatherInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 5bf055f3c012fef687cdc275d62efdf2d4cd5e5c..e14bcfa7f67e736a4d04f5b236fb2df02cf150e0 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -36,11 +37,11 @@ namespace xla { namespace { using Worklist = std::deque; -using Workset = std::unordered_set; +using Workset = absl::flat_hash_set; void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { - if (workset->count(instruction) == 0) { + if (!workset->contains(instruction)) { worklist->push_back(instruction); workset->insert(instruction); VLOG(3) << "ADD instruction: " << instruction->name(); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fe8371384c0fa3900a9022f101ff0b296439cf16..258f918f47a313b4b89fb260457b1b119dc16177 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -107,11 +107,10 @@ HloComputation* HloModule::AddEntryComputation( } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { - auto it = - std::find_if(computations_.begin(), computations_.end(), - [&to_remove](const std::unique_ptr& comp) { - return comp.get() == to_remove; - }); + auto it = absl::c_find_if( + computations_, [&to_remove](const std::unique_ptr& comp) { + return comp.get() == to_remove; + }); TF_RET_CHECK(it->get() == to_remove); computations_.erase(it); return Status::OK(); @@ -304,11 +303,10 @@ StatusOr> HloModule::CreateFromProto( auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. - std::sort(computations.begin(), computations.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(computations, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); // Add sorted computations to the module. for (auto& computation : computations) { @@ -392,15 +390,12 @@ namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given // subcomputation. -bool IsUsedOutsideSubcomputation( - const HloInstruction& hlo, - const std::unordered_set& instructions_in_subcomputation) { - for (HloInstruction* user : hlo.users()) { - if (!instructions_in_subcomputation.count(user)) { - return true; - } - } - return false; +bool IsUsedOutsideSubcomputation(const HloInstruction& hlo, + const absl::flat_hash_set& + instructions_in_subcomputation) { + return absl::c_any_of(hlo.users(), [&](HloInstruction* user) { + return !instructions_in_subcomputation.contains(user); + }); } } // anonymous namespace @@ -411,9 +406,9 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // A map from original instructions to their counterparts in the new outlined // function. - std::unordered_map outlined_instructions; + absl::flat_hash_map outlined_instructions; // A set that contains all instructions to be outlined. - std::unordered_set instruction_set_to_outline( + absl::flat_hash_set instruction_set_to_outline( instructions_to_outline.begin(), instructions_to_outline.end()); std::vector arguments; std::vector outputs; @@ -502,7 +497,7 @@ std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). - std::set nonroot_computations; + absl::flat_hash_set nonroot_computations; for (auto& computation : computations_) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -515,19 +510,19 @@ std::vector HloModule::MakeComputationPostOrder() const { // Keep track of computations which have already been added to the post // order. This prevents duplication as an embedded computation may be called // from two different root computations. - std::set added_computations; + absl::flat_hash_set added_computations; std::vector post_order; for (auto& computation : computations_) { - if (nonroot_computations.count(computation.get()) == 0) { + if (!nonroot_computations.contains(computation.get())) { for (HloComputation* embedded_computation : computation->MakeEmbeddedComputationsList()) { - if (added_computations.count(embedded_computation) == 0) { + if (!added_computations.contains(embedded_computation)) { post_order.push_back(embedded_computation); added_computations.insert(embedded_computation); } } // Root computations should only be encountered once. - CHECK_EQ(0, added_computations.count(computation.get())); + CHECK(!added_computations.contains(computation.get())); post_order.push_back(computation.get()); added_computations.insert(computation.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index e535b7d74943943069b4d795cf999a3b1e963360..f6e2866204955ac024c2b6f972de449cc3df4c15 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } // Returns whether the while instruction with name 'while_name' in diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 80f8ca2226b601d80af70395e485eb73246ddacb..47734bc55cc00d605f4e318400be88639450343c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -199,7 +199,7 @@ bool HloModuleGroupMetadata::IsChannelInstruction( } bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { - return companion_set_index_.count(hlo) > 0; + return companion_set_index_.contains(hlo); } bool HloModuleGroupMetadata::InstructionCommunicates( @@ -510,7 +510,7 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction( HloComputation* computation = instruction->parent(); const HloModule* module = computation->parent(); if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { + tracked_instructions_.contains(computation)) { return Status::OK(); } return FailedPrecondition("channel is used in disallowed computation"); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 5cd0e38f3883a325dffe586f6ec46be9f6b799ab..3ed95c10504141139d83eb8679a0b8144b15ad0d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -178,7 +178,7 @@ class HloModuleGroupMetadata { // Precondition: IsCompanionWhile(instruction) is true. const std::vector& Companions( const HloInstruction* instruction) const { - CHECK_EQ(companion_set_index_.count(instruction), 1); + CHECK(companion_set_index_.contains(instruction)); return companion_set(companion_set_index_.at(instruction)); } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index ca6a154809be46d6a0305c29e2b89219de408019..0cec61c257bb84e467290fb52ec9063a32ed558d 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -367,7 +367,7 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const HloInstruction* a, const HloInstruction* b) const { CHECK_EQ(a->parent(), b->parent()); // If either instruction is not in the order, then 'a' and 'b' are unordered. - if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { + if (!order_position_.contains(a) || !order_position_.contains(b)) { return false; } return order_position_.at(a) < order_position_.at(b); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 56848ce0e8558132d9706c5134db6a4dcef6bdfe..638396308c2a9c1f20e47f78b594d54f07c0c4e5 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1171,24 +1171,39 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.empty()) { + return Error(loc, "Expected at least one operand."); + } + if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) && + operands.size() != 1 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( - shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + shape, /*operand=*/operands[0], + /*start_indices=*/absl::MakeSpan(operands).subspan(1), *dynamic_slice_sizes)); break; } case HloOpcode::kDynamicUpdateSlice: { - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() < 2) { + return Error(loc, "Expected at least two operands."); + } + if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) && + operands.size() != 2 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, /*operand=*/operands[0], /*update=*/operands[1], - /*start_indices=*/operands[2])); + /*start_indices=*/absl::MakeSpan(operands).subspan(2))); break; } case HloOpcode::kTranspose: { @@ -2731,7 +2746,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } auto is_unique = [](string str) -> bool { - std::sort(str.begin(), str.end()); + absl::c_sort(str); return std::unique(str.begin(), str.end()) == str.end(); }; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index bc1a736766abc2fe0746b34b97fd0a5a17de462d..76b8a5bc117e653b3b2c09c7c95f26f1d4f27c7b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -566,12 +566,26 @@ ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) - ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} } +)" +}, +// Dynamic slice with scalar indices +{ +"DynamicSliceScalarIndices", +R"(HloModule DynamicSlice_module + +ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258]{2,1,0} parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} +} + )" }, // Dynamic update slice { "DynamicUpdateSlice", -R"(HloModule DynamicUpdateSlice_module +R"(HloModule DynamicSlice_module ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { %input = s32[1,1,25,1]{3,2,1,0} parameter(0) @@ -580,6 +594,23 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) } +)" +}, +// Dynamic update slice with scalar indices +{ +"DynamicUpdateSliceScalarIndex", +R"(HloModule DynamicUpdateSlice_module + +ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) +} + )" }, // batch norm training diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 33ce7e23a82d840676bba5f1ca9c0ffc4433465d..ae8c08cf1d16ad6738962f3be7c1b5512110b1d1 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -89,7 +89,7 @@ std::vector HloPassPipeline::GetEnabledPasses( std::vector enabled_passes; for (auto& pass : passes_) { - if (disabled_pass_names.count(string(pass->name())) == 0) { + if (!disabled_pass_names.contains(pass->name())) { enabled_passes.push_back(pass.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index 5eb707a957e49d86cdb2f72b72ce750bf29b8fd2..9cc202aa9f5fe5a20a9da05251ea811137ccaadb 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -34,11 +35,10 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, for (const HloComputationInfo& computation_info : hlo_profile_printer_data.computation_infos()) { const auto& instruction_infos = computation_info.instruction_infos(); - bool any_instruction_profiled = - std::any_of(instruction_infos.begin(), instruction_infos.end(), - [&](const HloInstructionInfo& instruction_info) { - return counters[instruction_info.profile_index()] != 0; - }); + bool any_instruction_profiled = absl::c_any_of( + instruction_infos, [&](const HloInstructionInfo& instruction_info) { + return counters[instruction_info.profile_index()] != 0; + }); if (!any_instruction_profiled) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index edaa4c59e2674e5f165c468059747d3dd2d54218..0fced7f15bdaf1dbe349e3b0fc6ada68393c6512 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper( absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. - if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { + if (!absl::c_linear_search(inputs, instruction)) { bit_vector->SetToZero(); } bit_vector->Set(GetIndex(instruction)); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index ac74e2432f2176e13eaf7d4a1934a50ee89d1042..9ca14ca18a1c47f3975cfdb57a03f3b6f03379df 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -235,8 +235,7 @@ class InstructionList { } // Now scan forwards until we find one of the before_instructions. - while (std::find(before_instructions.begin(), before_instructions.end(), - min_position_item) == before_instructions.end()) { + while (!absl::c_linear_search(before_instructions, min_position_item)) { min_position_item = min_position_item->next; } return InsertBefore(to_insert, min_position_item); @@ -302,7 +301,7 @@ ItemList GetUsers(const InstructionList& instruction_list, // A buffer may be used by the instruction via more than one alias. For // example, a buffer which appears in more than one element of a tuple. Item* user_item = instruction_list.GetItem(user); - if (std::find(users.begin(), users.end(), user_item) == users.end()) { + if (!absl::c_linear_search(users, user_item)) { users.push_back(user_item); } } @@ -456,8 +455,7 @@ class MemoryUsageTracker { return false; } const BufferIdList& in_progress_uses = in_progress_item_->buffers_used; - return std::find(in_progress_uses.begin(), in_progress_uses.end(), - buffer_id) != in_progress_uses.end(); + return absl::c_linear_search(in_progress_uses, buffer_id); } // Returns whether the given instruction is live at the current program @@ -535,8 +533,7 @@ MemoryUsageTracker::MemoryUsageTracker( bool unused; for (Item* user_item : GetUsers(instruction_list_, logical_buffer, points_to_analysis, &unused)) { - if (std::find(buffer->users.begin(), buffer->users.end(), - user_item) == buffer->users.end()) { + if (!absl::c_linear_search(buffer->users, user_item)) { buffer->users.push_back(user_item); buffer->unfinished_user_count++; user_item->buffers_used.push_back(buffer->id); @@ -784,8 +781,7 @@ bool MemoryUsageTracker::Check() const { for (const Buffer& buffer : buffers_) { if (buffer.defining_instruction->instruction == instruction) { - CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), - buffer.id) != defined_buffers.end()) + CHECK(absl::c_linear_search(defined_buffers, buffer.id)) << "Instruction " << instruction->name() << " defined buffers is missing: " << buffer.ToString(); } @@ -808,8 +804,7 @@ bool MemoryUsageTracker::Check() const { int64 unfinished_uses = 0; for (Item* user : buffer.users) { const BufferIdList& used_buffers = user->buffers_used; - CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != - used_buffers.end()) + CHECK(absl::c_linear_search(used_buffers, buffer.id)) << "Instruction " << user->instruction->name() << " used buffers is missing " << buffer.ToString(); if (!IsFinished(user)) { @@ -836,10 +831,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // If none of the users of 'instruction' have been placed in the sequence (as // tracked by memory_tracker), then rematerialization of 'instruction' is a // zero-cost move of 'instruction' in the sequence. - if (!std::any_of(instruction->users().begin(), instruction->users().end(), - [&memory_tracker](const HloInstruction* inst) { - return memory_tracker.IsPlaced(inst); - })) { + if (!absl::c_any_of(instruction->users(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { return 0; } diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 5a9b820a9d7f58695383b21c9e2126cf98970c83..d7d66ae1c4592723ca991d5ee971fa72cc1af90a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -383,9 +383,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } - return ServiceExecutableRunOptions( - run_options, backend().StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend().eigen_intra_op_thread_pool()); + return ServiceExecutableRunOptions(run_options, backend().StreamBorrower()); } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 8f6eb974c5179b420c8f961393ca923e0a3b3530..e75373501cffac6a736be89e9f6139b6ff2cdbc1 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -140,7 +140,7 @@ Status HloSchedule::UpdateComputationSchedule( std::queue worklist; for (HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { + if (!ids_in_schedule.contains(instruction->unique_id())) { // This is a newly added instruction which is not in the schedule. if (instruction->operands().empty()) { worklist.push(instruction); @@ -204,7 +204,7 @@ Status HloSchedule::Update() { std::vector nonfusion_computations = module_->MakeNonfusionComputations(); for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " not in HloSchedule."; } if (sequences_.size() > nonfusion_computations.size()) { @@ -215,7 +215,7 @@ Status HloSchedule::Update() { nonfusion_computations_ids.insert(computation->unique_id()); } for (auto it = sequences_.begin(); it != sequences_.end();) { - if (nonfusion_computations_ids.count(it->first) == 0) { + if (!nonfusion_computations_ids.contains(it->first)) { sequences_.erase(it++); } else { ++it; @@ -244,7 +244,7 @@ Status HloSchedule::Verify() const { << "Schedule has " << sequences_.size() << " sequences, but module has " << nonfusion_computations.size() << " non-fusion computations"; for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " missing from HLO schedule."; } @@ -268,7 +268,7 @@ Status HloSchedule::Verify() const { << instruction_position.size() << " instructions, expected " << computation->instruction_count(); for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) + TF_RET_CHECK(instruction_position.contains(instruction)) << "Instruction " << instruction->name() << " is not in schedule"; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 486ddbf499de80c634bc497158cd79ca066cc866..a5f54ae2c33259d080631061dff9ae40b41495dc 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -110,7 +110,7 @@ class HloSchedule { // Returns true if the schedule has a sequence for the given computation. bool is_computation_scheduled(const HloComputation* computation) const { - return sequences_.count(computation->unique_id()) == 1; + return sequences_.contains(computation->unique_id()); } // Updates the schedule such that it is (again) a valid schedule for the diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index bdd6ec1169a7d89cb4978a624d5606a5fb89df0f..37cc146bd7a6f2aef9373bd4afd8572ffac6473c 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/overflow_util.h" @@ -106,13 +107,12 @@ string HloSharding::ToString() const { bool HloSharding::UsesDevice(int64 device) const { if (IsTuple()) { - return std::any_of( - tuple_elements_.begin(), tuple_elements_.end(), - [&](const HloSharding& s) { return s.UsesDevice(device); }); + return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) { + return s.UsesDevice(device); + }); } const auto& devices = tile_assignment_; - return replicated_ || - std::find(devices.begin(), devices.end(), device) != devices.end(); + return replicated_ || absl::c_linear_search(devices, device); } std::map HloSharding::UsedDevices(int64* count) const { @@ -316,7 +316,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, // All tile assignments must be less than the number of available cores and // unique. Status status = Status::OK(); - std::set seen_cores; + absl::flat_hash_set seen_cores; tile_assignment_.Each( [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. @@ -324,7 +324,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, if (core >= num_devices) { status = tensorflow::errors::InvalidArgument(StrCat( "core ", core, " > ", num_devices, " in tile assignment")); - } else if (seen_cores.count(core) != 0) { + } else if (seen_cores.contains(core)) { status = tensorflow::errors::InvalidArgument( StrCat("core ", core, " is not unique in tile assignment")); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 9775505f8608ced3e33abe376f4922cc6a972726..5789ae09988d2a85247c5b8c037a172b3699f3b7 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -101,8 +101,8 @@ class HloSharding { if (!IsTuple()) { return replicated_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsReplicated(); }); + return absl::c_all_of( + tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); }); } // Returns true if the tile size is the same as the input size. @@ -110,8 +110,9 @@ class HloSharding { if (!IsTuple()) { return maximal_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsTileMaximal(); }); + return absl::c_all_of(tuple_elements_, [](const HloSharding& s) { + return s.IsTileMaximal(); + }); } // Returns true if the sharding defines an operation on the given device. diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index b414d2a66328284a4d8be8d35206bb837a2b3a58..094d98bc6e54028557f6d38cd165bf34e1fb8c46 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -99,7 +99,7 @@ std::vector LocatePassThroughDomainLinks( << "Instruction is not a kDomain: " << instruction->ToString(); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) != 0) { + domain.exit_domains.contains(user)) { pass_through.emplace_back(user, instruction); VLOG(2) << "Found passthrough domain link:"; VLOG(2) << " " << user->ToString(); @@ -253,7 +253,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) > 0) { + domain.exit_domains.contains(user)) { // If a user is a domain and it is registered in the domain exits, then // the instruction sharding is taken directly from the domain, and no // further users need to be visited. diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 4ea39a7628317867eb054bb14de48f721354cd9d..c1f69db74eafb7743e85f499f2f4828ed0375501 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -61,8 +61,7 @@ void CleanNodeName(string* name) { name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); const string chars_to_replace = "<>[]"; auto pred = [&](char c) { - return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != - chars_to_replace.end(); + return absl::c_linear_search(chars_to_replace, c); }; std::replace_if(name->begin(), name->end(), pred, '_'); } diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index d409df06be44a79d1bf7c90ce3ff34fca975fda4..218b33b2ac2b86edc30b2f014ba206c71da37682 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -209,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) { } void HloValueSet::SortAndUniquifyValues() { - std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); + absl::c_sort(values_, HloValue::IdLessThan); values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), values_.end()); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8d8720d7be58b7f02dbeba7b2394956b6e93a43f..2c69c27ce55e0f9f6d185f807e6d43d6f2cfe8d4 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -387,6 +387,14 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); + // Bitcasts are not allowed to change the element type. + if (bitcast->operand(0)->shape().element_type() != + bitcast->shape().element_type()) { + return InternalError( + "Bitcast can not change the element type from %s to %s", + PrimitiveType_Name(bitcast->operand(0)->shape().element_type()), + PrimitiveType_Name(bitcast->shape().element_type())); + } return Status::OK(); } @@ -496,21 +504,23 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); - return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( - dynamic_slice->operand(0)->shape(), - dynamic_slice->operand(1)->shape(), - dynamic_slice->dynamic_slice_sizes())); + return CheckShape( + dynamic_slice, + ShapeInference::InferDynamicSliceShape( + dynamic_slice->operand(0)->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); } Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); - return CheckShape(dynamic_update_slice, - ShapeInference::InferDynamicUpdateSliceShape( - dynamic_update_slice->operand(0)->shape(), - dynamic_update_slice->operand(1)->shape(), - dynamic_update_slice->operand(2)->shape())); + return CheckShape( + dynamic_update_slice, + ShapeInference::InferDynamicUpdateSliceShape( + dynamic_update_slice->operand(0)->shape(), + dynamic_update_slice->operand(1)->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); } Status ShapeVerifier::HandleTuple(HloInstruction* tuple) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index a1a6aba9728c137d17487b5914f67cb3966fc12b..479905b317d5639ff2cebc4d1044e21b527693f6 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -168,8 +168,13 @@ class ShapeVerifier : public DfsHloVisitor { // An interface used to encapsulate target-specific verification quirks. class TargetVerifierMetadata { public: + TargetVerifierMetadata(std::function shape_size_function) + : shape_size_function_(shape_size_function) {} + // Returns a target-specific shape size. - virtual int64 ShapeSize(const Shape& shape) const = 0; + int64 ShapeSize(const Shape& shape) const { + return shape_size_function_(shape); + } virtual std::unique_ptr GetVerifier() const = 0; @@ -178,20 +183,23 @@ class TargetVerifierMetadata { TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; + + private: + // Returns a target-specific shape size. + std::function shape_size_function_; }; // The default implementation of TargetVerifierMetadata, used unless the target // needs to override it. class DefaultVerifierMetadata : public TargetVerifierMetadata { public: - DefaultVerifierMetadata(bool layout_sensitive, bool allow_mixed_precision) - : layout_sensitive_(layout_sensitive), + DefaultVerifierMetadata( + bool layout_sensitive, bool allow_mixed_precision, + std::function shape_size_function) + : TargetVerifierMetadata(shape_size_function), + layout_sensitive_(layout_sensitive), allow_mixed_precision_(allow_mixed_precision) {} - int64 ShapeSize(const Shape& shape) const override { - return ShapeUtil::ByteSizeOf(shape); - } - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This creates a new verifier every time because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object for each run of @@ -210,11 +218,14 @@ class DefaultVerifierMetadata : public TargetVerifierMetadata { // the module. class HloVerifier : public HloModulePass { public: - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, - std::function - instruction_can_change_layout_func = {}) + explicit HloVerifier( + bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}, + std::function shape_size_func = + [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) : target_metadata_(absl::make_unique( - layout_sensitive, allow_mixed_precision)), + layout_sensitive, allow_mixed_precision, shape_size_func)), instruction_can_change_layout_func_( std::move(instruction_can_change_layout_func)) { CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 4bc557e4e62e7df4e25fda86fe417e84129b464c..d27c62a879bedf316508b6ff95ab2536106b40d0 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -386,6 +388,55 @@ TEST_F(HloVerifierTest, AddWithLayoutChange) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258] parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicUpdateSlice_module + + ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); auto status = verifier().Run(module.get()).status(); @@ -399,8 +450,9 @@ TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { HloModule SliceWithLayoutChange ENTRY SliceWithLayoutChange { par0 = f32[4,5]{0,1} parameter(0) - par1 = s32[2] parameter(1) - ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + par1 = s32[] parameter(1) + par2 = s32[] parameter(2) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2), dynamic_slice_sizes={3,4} } )"; @@ -429,5 +481,23 @@ TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { EXPECT_THAT(status.error_message(), HasSubstr("Instruction shouldn't change layouts")); } + +TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY BitcastCanNotChangeElementType { + constant.0 = f32[2] constant({0.0, 0.0}) + ROOT bitcast = s32[2] bitcast(constant.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Bitcast can not change the element type")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 90904ac00110457bcc3b8974816a7080c4ab89fc..88fc62bd1e2a7830b3f61738a8642308ef4225a7 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const { // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); - std::sort( - sorted_ops.begin(), sorted_ops.end(), - [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); + absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) { + return a.cycles > b.cycles; + }); for (const auto& op : sorted_ops) { print_op(op); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index a41cf714c5ed6adcf7aa1f0e54cf052594834b5d..76bf48870d55e82497ba5f63e9e2e2a322cb330e 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -103,7 +103,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( do { const HloInstruction* instr = stack.back(); - if (cache_.count(instr)) { + if (cache_.contains(instr)) { stack.pop_back(); continue; } @@ -111,9 +111,9 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( switch (FindOrDie(dfs_state_map, instr)) { case kDiscovered: { for (const HloInstruction* operand : instr->operands()) { - if (!cache_.count(operand)) { + if (!cache_.contains(operand)) { stack.push_back(operand); - CHECK(!dfs_state_map.count(operand) || + CHECK(!dfs_state_map.contains(operand) || dfs_state_map[operand] == kDiscovered); dfs_state_map[operand] = kDiscovered; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f6f25c44131226d35f8e927d62defee291bf46dd..b97060535d998e174639dceca5cde517cef01e30 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -178,19 +178,18 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); } }); - return std::count_if(hlo->operands().begin(), hlo->operands().end(), - [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kIota) { - return false; - } - if (operand->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(operand->shape())) { - return false; - } - return ShapeUtil::TrueRank(operand->shape()) >= - output_rank; - }) <= 1; + return absl::c_count_if( + hlo->operands(), [output_rank](HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { + return false; + } + if (operand->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(operand->shape())) { + return false; + } + return ShapeUtil::TrueRank(operand->shape()) >= output_rank; + }) <= 1; } bool InstructionFusion::CanFuseOnAllPaths( @@ -409,9 +408,8 @@ class ReversePostOrderFusionQueue : public FusionQueue { } sorted_operand_numbers.push_back(i); } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { + absl::c_sort( + sorted_operand_numbers, [&](int64 i, int64 j) { // Instructions with higher priority in the queue come first. return ( FindOrDie(post_order_index_, instruction->mutable_operand(i)) > diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index a981d94a999e3d322986bc2bfd56a5b0b5d175fc..9b4cafa4b42213d589ff25ce73d0fdfecc678dc2 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -1,12 +1,12 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - load( "//tensorflow/core:platform/default/build_config_root.bzl", "if_static", ) +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + cc_library( name = "interpreter_transfer_manager", srcs = ["interpreter_transfer_manager.cc"], @@ -32,8 +32,10 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -41,12 +43,14 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", @@ -115,6 +119,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/stream_executor/host:host_stream", + "//tensorflow/stream_executor/host:host_timer", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index d37ae94bf6c4c697bbf30390c02a5099271e00a4..4818b2dae0a9951346600a9b2906488c3ef7e06e 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -43,9 +45,15 @@ namespace interpreter { Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); + pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout); + + ReducePrecisionInsertion::AddPasses( + &pipeline, hlo_module->config().debug_options(), + ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 88e33464a27137fd56021e232cc93a832d1a5b3f..7a6ebdef708bcc3a92fbd8618db0c42c35e6ce8b 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -68,6 +68,18 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } + // Check that the args have the right shape. + for (int64 i = 0; i < computation->num_parameters(); ++i) { + const auto& expected_shape = computation->parameter_instruction(i)->shape(); + const auto& actual_shape = arguments[i]->on_device_shape(); + if (!ShapeUtil::Equal(expected_shape, actual_shape)) { + return InvalidArgument( + "Shape mismatch on parameter %d. Expected %s, but was %s.", i, + ShapeUtil::HumanString(expected_shape), + ShapeUtil::HumanString(actual_shape)); + } + } + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, TransferManager::GetForPlatform(platform)); @@ -86,8 +98,8 @@ StatusOr InterpreterExecutable::ExecuteOnStream( { tensorflow::mutex_lock lock(evaluator_lock_); evaluator_->ResetVisitStates(); - TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate(*computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 1c4ac178f0b9a2789d4e6d4554e3a56e0e7cf77f..10ff7bb6d46ee3b2cd1228b4b7a49269be8c65d3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded( PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); PointsToSet::BufferSet* operand_buffers = GetBufferSet(instruction->operand(operand_no)); - for (const LogicalBuffer* output_buffer : *output_buffers) { - if (operand_buffers->count(output_buffer) > 0) { - return true; - } - } - return false; + return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) { + return operand_buffers->count(b) > 0; + }); } Status LayoutConstraints::SetBufferLayout(const Layout& layout, @@ -1236,6 +1233,23 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( }); } +namespace { +// A transpose or a reshape that only changes trivial dimensions have meaningful +// layouts that are valuable to propagate in a depthfirst manner to avoid +// unassigned layouts in the graph. +bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) { + switch (hlo.opcode()) { + case HloOpcode::kReshape: + return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + case HloOpcode::kTranspose: + return true; + default: + return false; + } +} + +} // namespace + Status LayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& operand_constraint, LayoutConstraints* constraints) { @@ -1370,7 +1384,7 @@ Status LayoutAssignment::PropagateOperandConstraint( TF_RETURN_IF_ERROR(constraints->SetBufferLayout( *layout, *buffer, /*mandatory=*/user->opcode() == HloOpcode::kReduce, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*user))); } } return Status::OK(); @@ -1420,11 +1434,9 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), instruction, operand_no); if (operand_layout != nullptr) { - // Do not propagate operand constraints of transposes and reshapes, it - // tends to create really bad layouts. TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( *operand_layout, instruction, operand_no, /*mandatory=*/false, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction))); } } else { VLOG(6) << "Operand already has a constraint " @@ -2120,7 +2132,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kCopy && - added_copies_.count(instruction) > 0) { + added_copies_.contains(instruction)) { VLOG(5) << "Removing added copy: " << instruction->ToString(); TF_RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 3b081de3c7826c3c11a7d87d542835d0ecce1b7e..5701cb5b025e563247d46d0d24f81a5f886fc23b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -243,7 +243,7 @@ class ChannelLayoutConstraints { // Returns true if channel_id has a layout constraint. bool IsChannelConstrained(int64 channel_id) const { - return constraints_.count(channel_id) > 0; + return constraints_.contains(channel_id); } // Given `shape`, apply the layout for `channel_id`. `channel_id` must already @@ -276,7 +276,7 @@ class ChannelLayoutConstraints { } private: - std::unordered_map constraints_; + absl::flat_hash_map constraints_; }; // HLO pass which assigns layouts to all instructions in the HLO module while diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 387b385157ab6ece65c692de65b4da33038f1f30..c8cf3c47d380012fdb0206c0d20d67e6a13017ae 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -960,8 +960,9 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { par0 = f32[3,4]{1,0} parameter(0) par1 = f32[4,5]{0,1} parameter(1) - par2 = s32[2] parameter(2) - dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + par2 = s32[] parameter(2) + par3 = s32[] parameter(3) + dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4} ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) } )"; @@ -982,7 +983,7 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { m::Parameter(), m::DynamicSlice( m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), - m::Parameter(2))))); + m::Parameter(2), m::Parameter(3))))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 728a66b388f0f9af480ff88b5e96990a26e36af5..54e8fe1947eea43b89b6fae70655a4f52ea8e69d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 643ecd0fbaa546c551097b29e74ccd49418e1466..ce3d922ca7a9bdea3a520959a8b8d284bc3e0d64 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (hlo.opcode() == HloOpcode::kParameter) { const std::vector& parameter_instructions = module_.entry_computation()->parameter_instructions(); - if (std::find(parameter_instructions.begin(), - parameter_instructions.end(), - &hlo) != parameter_instructions.end()) { + if (absl::c_linear_search(parameter_instructions, &hlo)) { array->MarkInvariantOverWholeProgram(context_); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 2b46b3c3964b15548dbacc8b0ada0047a0fa85b6..12e2f449e23ac2511aac576fed893f5a9ef510c0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -76,15 +76,12 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - absl::flat_hash_map + absl::flat_hash_map alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. - absl::flat_hash_map - noalias_metadata_; + absl::flat_hash_map noalias_metadata_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 1da77945328c42826c305b778b4964f53dece90d..c66eaec8fb0e4c03f6967fec0cf0ae9661cdf470 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -36,8 +36,10 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, // EmitFusedDynamicUpdateSliceInPlace. // // Emits a sequential loop if launch_dimensions is null. +using IndexGenerator = std::function(int64)>; + static Status EmitDynamicUpdateSliceInPlaceImpl( - const Shape& update_shape, const ElementGenerator& start_indices_generator, + const Shape& update_shape, const IndexGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, absl::string_view name, llvm::IRBuilder<>* b) { @@ -47,8 +49,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( const int64 rank = output_shape.rank(); IrArray::Index start_index(b->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { - IrArray::Index dim_index({b->getInt64(i)}); - TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(i)); llvm::Value* output_dim_size = llvm::ConstantInt::get( start_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( @@ -112,9 +113,20 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, Shape output_shape = output_array.GetShape(); Shape update_shape = update_array.GetShape(); - ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { - return start_indices_array.EmitReadArrayElement(index, b); - }; + IndexGenerator start_indices_generator; + // TODO(b/118437727): Remove the R1 path, and rename the variables. + if (start_indices_array.GetShape().rank() == 1) { + start_indices_generator = [&](int64 index) { + return start_indices_array.EmitReadArrayElement( + IrArray::Index({b->getInt64(index)}), b); + }; + } else { + start_indices_generator = [&](int64 index) { + return operand_arrays[2 + index].EmitReadArrayElement( + IrArray::Index(b->getInt64Ty()), b); + }; + } + ElementGenerator update_array_generator = [&](const IrArray::Index& index) { return update_array.EmitReadArrayElement(index, b); }; @@ -165,8 +177,21 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( elemental_emitter); TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); - ElementGenerator start_indices_generator = - fused_emitter.GetGenerator(start_indices); + + // TODO(b/118437727): Remove the R1 path, and rename the variables. + IndexGenerator start_indices_generator; + if (start_indices->shape().rank() == 1) { + start_indices_generator = [&](int64 index) { + return fused_emitter.GetGenerator(start_indices)( + IrArray::Index({b->getInt64(index)})); + }; + } else { + start_indices_generator = [&](int64 index) { + ElementGenerator element_generator = + fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); + return element_generator(IrArray::Index(b->getInt64Ty())); + }; + } bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 03a475b40b5a1f5100091fc11d744792b641cb04..e440f05e2b2f0d4a2a4c7b326b4881183de4d235 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -35,7 +35,7 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - if (generated_value_cache_[hlo].count(index.multidim()) > 0) { + if (generated_value_cache_[hlo].contains(index.multidim())) { llvm::Value* generated_value = generated_value_cache_[hlo][index.multidim()]; llvm::BasicBlock* generated_value_bb = nullptr; diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 1b9c61f6700e2a1309b21e499f4a9e2439ed3702..e6d52a580c04a920d3f0e8ed6f39c1cae587cf1b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" @@ -134,8 +135,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges - std::unordered_map, llvm::Value*>> + absl::flat_hash_map< + const HloInstruction*, + absl::flat_hash_map, llvm::Value*>> generated_value_cache_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index d6d84994ee147f4b8c1a333b0eaccdf6e0a2219b..a483f7051f268edf621a861ee09fb8c2376a8fc6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -189,6 +189,8 @@ class IrArray { return llvm::ConstantInt::get(index_type_, c); } + void ClearLinearIndex() { linear_ = nullptr; } + private: // Changing the multi-dimensional index invalidates the linear index. std::vector& mutable_multidim() { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index cebbc4290163d4e98003cd7b5df6ec906509a446..cd8dd72cd775d5e0b52f96a2326367da0775e7eb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -123,7 +123,8 @@ KernelMappingScheme::KernelMappingScheme( dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), tile_sizes_{1, tile_size_y, tile_size_x}, num_threads_x_(num_threads_x), - num_threads_y_(num_threads_y) { + num_threads_y_(num_threads_y), + dilated_x_(true) { DCHECK_EQ(dims_in_elems_.size(), 3); DCHECK_EQ(req_block_sizes.size(), 3); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index fb633b12e60d1a9f3103fb2919ad2c3f3f14de20..f802cc27d519e621262f328903697373aa8c284c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -117,7 +117,10 @@ class KernelMappingScheme { int64 GetNumberOfTilesInOneBlock() const { return absl::c_accumulate(block_sizes_, 1, std::multiplies()); } - + int64 GetNumberOfTilesInOneBlockForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return block_sizes_[d]; + } int64 GetNumberOfBlocks() const { return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); } @@ -147,6 +150,16 @@ class KernelMappingScheme { GetNumberOfThreadsForDimensionY(); } + bool DilatedX() const { return dilated_x_; } + void SetDilatedX(bool v) { + dilated_x_ = v; + if (!dilated_x_) { + // dilated_x_=false is for the purpose of vectorization, which requires + // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_. + CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0); + } + } + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); // Returns the index for the first tile in the block with the given block // index. @@ -186,6 +199,13 @@ class KernelMappingScheme { int64 num_threads_x_; // Number of threads used to process elements in the Y direction of a tile. int64 num_threads_y_; + + // When num_threads_x threads process a total of tile_size_x elements in the + // X dimension of a tile, each threads process n=tile_size_x/num_threads_x + // elements. When dilated_x=false, the n elements processed by a thread are + // contiguous. On the other hand, when dilated_x=true the n elements are + // dilated by a factor of num_threads_x. + bool dilated_x_; }; // A class to represent information for tiled parameters to support IR emission diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 9ccdd7d8d818b9fa3aa77cdd10d37ca18928b448..53d52d9a3d918fa6dee093668923fcfff963d084 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -198,7 +198,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { continue; } - if (in_list.count(instr) > 0) { + if (in_list.contains(instr)) { continue; } int64 profit = GetProfit(instr, fusion); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index daa718879ddd45afb02725b557380b2f49fe833e..f6feed29935a1446499559d947dff0a8eefe5d2e 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -34,7 +34,7 @@ bool IsAllowed(char character) { } // namespace NameUniquer::NameUniquer(const string& separator) { - CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + CHECK(absl::c_all_of(separator, IsAllowed)) << "separator should comprises allowed characters only"; separator_ = separator; } diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index c362a60d949a5b823b27b23cb0cb5f365e318cb3..9e3d1060210790f60243195a1c1dff13f1fc7fc5 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -1878,7 +1878,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().EqualTo(shape))) { return WithShape(Shape().EqualTo(shape)); } @@ -1886,7 +1886,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { return WithShape(Shape().CompatibleTo(shape)); } @@ -2057,7 +2057,6 @@ XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Slice) -XLA_UNOP_PATTERN(Sort) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) #undef XLA_UNOP_PATTERN @@ -2119,7 +2118,6 @@ XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(DynamicSlice) XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) @@ -2236,8 +2234,10 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, XLA_VARIADIC_OP_PATTERN(AfterAll); XLA_VARIADIC_OP_PATTERN(Concatenate); XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(DynamicSlice) XLA_VARIADIC_OP_PATTERN(Map) XLA_VARIADIC_OP_PATTERN(Reduce); +XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for matching non-constant instructions. diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 896b73cda41cb21b539b586aa4701c5bad43f8b9..0491f641fc7835c404612d6f8dcda83a02ca97d6 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -260,8 +260,8 @@ PlatformUtil::GetStreamExecutors( // Block here in thread_pool destructor until all devices are initialized. } VLOG(1) << "Device initialization complete"; - if (std::all_of(stream_executors.begin(), stream_executors.end(), - [](se::StreamExecutor* s) { return s == nullptr; })) { + if (absl::c_all_of(stream_executors, + [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", platform->Name()); } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index e8496dbd72bbe0a2e78067dda209cdb4b620e826..036c3c36f648daf8963a6b25e300b93c1bdf78d9 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -26,7 +26,6 @@ limitations under the License. namespace xla { - // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. static StatusOr TransposeIndexVectorDimToLast( @@ -60,6 +59,13 @@ static StatusOr CanonicalizeScatterIndices( TF_ASSIGN_OR_RETURN( HloInstruction * transposed_scatter_indices, TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + if (scatter_indices->shape().rank() == index_vector_dim + 1 && + scatter_indices->shape().dimensions(index_vector_dim) == 1) { + auto new_shape = + ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); + TF_ASSIGN_OR_RETURN(scatter_indices, + MakeReshapeHlo(new_shape, scatter_indices)); + } bool indices_are_scalar = index_vector_dim == scatter_indices->shape().dimensions_size(); @@ -214,9 +220,6 @@ static StatusOr> ScatterLoopBody( HloInstruction* updates = loop_state[2]; bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; - CHECK_EQ(has_scalar_indices, - dim_numbers.index_vector_dim() == - scatter->operand(1)->shape().dimensions_size()); // Build a vector form of the induction variable of the while loop. TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index a0126f39b3dc4281abedc36a19dd20c3b128e249..2732d498d80121fcbff037d4e3bcd226c61cae2f 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -552,9 +553,7 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - run_options.emplace_back( - options, backend->StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend->eigen_intra_op_thread_pool()); + run_options.emplace_back(options, backend->StreamBorrower()); } if (options_.number_of_replicas() == 1) { @@ -1097,9 +1096,12 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module.get())); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( - *module, /*arg_literals=*/{})); + evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index dbfed628bfcabffe66bef41a82e0e2430897d80d..6bee671056552b83014367889320b748659bbfdf 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -32,12 +32,10 @@ class ServiceExecutableRunOptions { ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} - explicit ServiceExecutableRunOptions( - ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) + explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, + StreamBorrower borrow_stream = nullptr) : run_options_(std::move(run_options)), - borrow_stream_(std::move(borrow_stream)), - xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {} + borrow_stream_(std::move(borrow_stream)) {} // Returns reference or pointer to `ExecutableRunOptions` member. const ExecutableRunOptions& run_options() const { return run_options_; } @@ -56,15 +54,9 @@ class ServiceExecutableRunOptions { : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); } - // Returns reference to thread pool for execution of XLA ops on CPU backend. - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const { - return xla_intra_op_thread_pool_; - } - private: ExecutableRunOptions run_options_; StreamBorrower borrow_stream_; - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index b0e241d216dc5c6c3d6cbde2c12350f0647dd8ef..1d3f84af955df0736d8bbd87fe71152d8631bfee 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -534,9 +534,8 @@ Status ValidateDotDimensionNumbers( absl::Span contracting_dims, absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - in_range) && - std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + return absl::c_all_of(contracting_dims, in_range) && + absl::c_all_of(batch_dims, in_range); }; absl::Span lhs_contracting_dimensions = @@ -563,9 +562,8 @@ Status ValidateDotDimensionNumbers( auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - is_unique) && - std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + return absl::c_all_of(contracting_dims, is_unique) && + absl::c_all_of(batch_dims, is_unique); }; if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || @@ -1589,29 +1587,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, input_dnums[1] = dnums.input_feature_dimension(); std::copy(dnums.input_spatial_dimensions().begin(), dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); - std::sort(input_dnums.begin(), input_dnums.end()); + absl::c_sort(input_dnums); std::vector window_dnums(num_dims); window_dnums[0] = dnums.kernel_input_feature_dimension(); window_dnums[1] = dnums.kernel_output_feature_dimension(); std::copy(dnums.kernel_spatial_dimensions().begin(), dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); - std::sort(window_dnums.begin(), window_dnums.end()); + absl::c_sort(window_dnums); std::vector output_dnums(num_dims); output_dnums[0] = dnums.output_batch_dimension(); output_dnums[1] = dnums.output_feature_dimension(); std::copy(dnums.output_spatial_dimensions().begin(), dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); - std::sort(output_dnums.begin(), output_dnums.end()); + absl::c_sort(output_dnums); std::vector expected_dnums(num_dims); std::iota(expected_dnums.begin(), expected_dnums.end(), 0); const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; - if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || - !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || - !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { + if (!absl::c_all_of(input_dnums, in_range) || + !absl::c_all_of(window_dnums, in_range) || + !absl::c_all_of(output_dnums, in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", dnums.DebugString()); @@ -2087,35 +2085,81 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes) { + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR( - ExpectArray(start_indices_shape, "start indices of dynamic slice")); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. + if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic slice should have exactly 1 index operand, has %d.", + number_of_indices); + } - VLOG(2) << StrFormat( - "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); + const Shape& start_indices_shape = start_index_shapes[0]; + VLOG(2) << StrFormat( + "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + StrJoin(slice_sizes, ", ")); - if (start_indices_shape.rank() != 1) { - return InvalidArgument( - "Dynamic slice start indices of rank %d must be rank1.", - start_indices_shape.rank()); - } + TF_RETURN_IF_ERROR( + ExpectArray(start_indices_shape, "start indices of dynamic slice")); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return InvalidArgument( - "Dynamic slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (operand_shape.rank() != start_num_dims) { - return InvalidArgument( - "Dynamic slice start number of dimensions %d (%s) must match rank " - "%d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + "Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + StrJoin(slice_sizes, ", ")); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic slice start number of dimensions %d must match rank " + "%d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument("Dynamic slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + for (const Shape& index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must all have the same shape, got " + "mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + ShapeUtil::HumanString(index_shape)); + } + } + } } if (slice_sizes.size() != operand_shape.rank()) { @@ -2144,39 +2188,85 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, - const Shape& start_indices_shape) { + absl::Span start_index_shapes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR( ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( ExpectArray(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, - "start indices of dynamic update slice")); - VLOG(2) << StrFormat( - "updating slice of shape %s at dynamic start_indices %s with update " - "shape %s", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::HumanString(update_shape)); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. + if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic update slice should have exactly 1 index operand, has %d.", + number_of_indices); + } + const Shape& start_indices_shape = start_index_shapes[0]; + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); - if (start_indices_shape.rank() != 1) { - return InvalidArgument( - "Dynamic update slice start indices of rank %d must be rank1.", - start_indices_shape.rank()); - } + VLOG(2) << StrFormat( + "updating slice of shape %s at dynamic start_indices %s with update " + "shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return InvalidArgument( - "Dynamic update slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic update slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (operand_shape.rank() != start_num_dims) { - return InvalidArgument( - "Dynamic update slice start number of dimensions %d (%s) must match " - "rank %d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("updating slice of shape %s with update shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(update_shape)); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d must match rank " + "%d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + for (const Shape& index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must all have the same " + "shape, got mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + ShapeUtil::HumanString(index_shape)); + } + } + } } if (update_shape.rank() != operand_shape.rank()) { @@ -2268,7 +2358,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, }; // Check the shapes of computation parameters and return types. - if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { + if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Condition must return a boolean; got %s.", shape_string()); } @@ -2288,7 +2378,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& predicate, const Shape& true_operand, const Shape& false_operand, const ProgramShape& true_computation, const ProgramShape& false_computation) { - if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + if (!ShapeUtil::Equal(predicate, ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Predicate must be a boolean; got %s.", ShapeUtil::HumanString(predicate)); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 1b8fd10d691498087b28ef68517868c5def1da5a..7d39ef38e05abf0a81683c1fb0f3999908b27d23 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -176,14 +176,15 @@ class ShapeInference { // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes); + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. static StatusOr InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, - const Shape& start_indices_shape); + absl::Span start_index_shapes, + bool allow_scalar_indices = true); // Infers the shape produced by doing a compile-time-constant indexing into // the given input shape. This is essential for operations on tuples, because diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 15eb46bac0ac4e03d476ee21be6834174798526c..a95ca2bf2a8fcd700eb9234cafbfce9b62f2370c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_lhs; const int64 kLhsIdx = 0; - if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kLhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); @@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_rhs; const int64 kRhsIdx = 1; - if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kRhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index b1f0672c6089035153aa48c853bface116f70026..5e505aaf02f157d0cba9dff42b1a9b89a6691504 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -55,11 +56,10 @@ bool PointsToSet::IsAmbiguous() const { bool PointsToSet::IsDistinct() const { bool distinct = true; - std::set all_points_to; - ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/, - const BufferList& points_to) { + absl::flat_hash_set all_points_to; + ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) { for (auto& buffer : points_to) { - if (all_points_to.count(buffer) != 0) { + if (all_points_to.contains(buffer)) { distinct = false; } all_points_to.insert(buffer); @@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool found = false; ForEachElement([&found, &buffer](const ShapeIndex& /*index*/, const BufferList& pointed_to_buffers) { - if (!found && - std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end()) { + if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) { found = true; } }); @@ -99,8 +97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer, const ShapeIndex& index) const { const auto& pointed_to_buffers = element(index); - return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end(); + return absl::c_linear_search(pointed_to_buffers, &buffer); } void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer, @@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( } else if (user->opcode() == HloOpcode::kFusion && user->fusion_kind() == HloInstruction::FusionKind::kLoop) { // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { + auto it = absl::c_find_if( + user->fused_parameters(), [&](HloInstruction* fused_param) { return user->operand(fused_param->parameter_number()) == operand; }); CHECK(it != user->fused_parameters().end()); @@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( } // Find fusion parameter associated with 'operand'. const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { + auto fused_param_it = + absl::c_find_if(fused_params, [&](HloInstruction* fused_param) { return fusion->operand(fused_param->parameter_number()) == operand; }); if (fused_param_it == fused_params.end()) { @@ -743,11 +738,10 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index d8875ca74716420ba0d54e4aba53c99001989996..7599e1e6adcff4571068c59b074fba5aaa9e9e8d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -721,9 +721,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // to fusion 'operand'. HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion, HloInstruction* operand) { - auto it = std::find_if( - fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), [=](const HloInstruction* fused) { + auto it = absl::c_find_if( + fusion->fused_instructions(), [&](const HloInstruction* fused) { return fused->opcode() == HloOpcode::kParameter && fusion->operand(fused->parameter_number()) == operand; }); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 68e2569f66bea9ec1223e454d1ead0efc7b9498e..c93a9ba3176002a34fe84a29e62075de4d19168f 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -301,7 +301,7 @@ optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { /*dest_shape_index=*/{indvar_index}, /*src_shape_index=*/{})); StatusOr eval_result = - evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); if (!eval_result.ok()) { VLOG(2) << "Couldn't evaluate while loop condition."; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index a1c627a319f19c83fd676000f9de971efe377dbd..69cc8feb3f31ad782b9d3437d81d0ab8ce10aadb 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -89,7 +89,7 @@ static void CreateLoopInvariantCopy( HloInstruction* next_operand = frame->instruction->mutable_operand(frame->operand_index++); - if (hoisted_instructions->count(next_operand) || + if (hoisted_instructions->contains(next_operand) || next_operand == while_body_param) { continue; } @@ -241,7 +241,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || - unhoisted_invariant_instructions.count(op) || + unhoisted_invariant_instructions.contains(op) || op->opcode() == HloOpcode::kConstant; }; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8e7c4bc8828552e197b41f874c070d496b85a382..3587c016b4420163a607422b1acc838646fab83a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -299,7 +299,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); @@ -314,10 +314,12 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); HloInstruction* in_token = builder.AddInstruction( HloInstruction::CreateGetTupleElement(token_shape, param, 2)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); - HloInstruction* out_token = builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = + builder.AddInstruction(HloInstruction::CreateOutfeed( + effective_scalar_s32, bitcast_inst, in_token, "")); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); @@ -352,9 +354,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = - ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); + Shape while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, effective_scalar_s32, effective_scalar_s32}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -363,12 +365,13 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { HloInstruction* gte_0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction::CreateGetTupleElement(effective_scalar_s32, param, 1)); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); HloInstruction* add_inst = builder.AddInstruction(HloInstruction::CreateBinary( - scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); + effective_scalar_s32, HloOpcode::kAdd, bitcast_inst, gte_1)); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 772faa25e299ebd96d46545905bbc66b6e23c82e..09d54095718029541a7a25aa62f9a2e9a177960d 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -109,8 +109,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // operand appears in, but it may appear more than once! if (user->user_count() == 1 && user->users().front() == while_body_root && while_body_root->operand_index(user) == user->tuple_index() && - std::count(while_body_root->operands().begin(), - while_body_root->operands().end(), user) == 1) { + absl::c_count(while_body_root->operands(), user) == 1) { continue; } @@ -127,7 +126,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // through to the while body's root, count that element as "used", since // removing that element would be observable. for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.count(i)) { + if (used_tuple_indices.contains(i)) { continue; } @@ -158,7 +157,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), used_tuple_indices.end()); - std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); + absl::c_sort(new_to_old_tuple_idx); absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 3713989ca2f64ee1d94c9f77255017909d957da2..ecca76b1e86d833c73fbb9bad6a341660a7d2669 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -407,13 +407,12 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. + const auto& instrs = m->entry_computation()->instructions(); HloInstruction* new_while_op = - *std::find_if(m->entry_computation()->instructions().begin(), - m->entry_computation()->instructions().end(), - [&](const HloInstruction* instr) { - return (instr->opcode() == HloOpcode::kWhile && - instr->name() != "while"); - }); + *absl::c_find_if(instrs, [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index c04a91615437d315934866448b70846ccf3bad40..1a029efe8543b5433ef5fe7923e1e804019ba0c0 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -27,7 +27,19 @@ Shape::Shape(const ShapeProto& shape_proto) { for (const int64 dimension : shape_proto.dimensions()) { add_dimensions(dimension); } - for (int i = 0; i < shape_proto.is_dynamic_dimension_size(); i++) { + // A malformed proto may have different is_dynamic_dimension_size and + // dimensions_size. Since C++ is evil, and we have no good way of bailing out + // in a constructor, conservatively trim the is_dynamic_dimension size. + // TODO(b/120111794): Make this a hard error when we have a factory method + // instead of a constructor. + if (shape_proto.dimensions_size() != + shape_proto.is_dynamic_dimension_size()) { + LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " + "fields does not match number of dimension fields"; + } + int64 num_dynamic_dimension_fields = std::min( + shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); + for (int i = 0; i < num_dynamic_dimension_fields; i++) { dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i); } tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); @@ -68,19 +80,18 @@ string Shape::ToString(bool print_layout) const { } bool Shape::is_static() const { - if (ShapeUtil::IsTuple(*this)) { + if (IsTuple()) { for (const Shape& subshape : tuple_shapes_) { if (!subshape.is_static()) { return false; } } } - return !std::any_of(dynamic_dimensions_.begin(), dynamic_dimensions_.end(), - [](bool b) { return b; }); + return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); } void Shape::DeleteDimension(int64 dim_to_delete) { - CHECK(ShapeUtil::IsArray(*this)); + CHECK(IsArray()); CHECK_GE(dim_to_delete, 0); CHECK_LT(dim_to_delete, dimensions_.size()); dimensions_.erase(dimensions_.begin() + dim_to_delete); diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index fb5dc8fba430fdb52eb0ea6a98356032e7beef44..dc4cdc31a74d43471b72a71d9d436408e0e62deb 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -138,7 +138,7 @@ class Shape { string ShortDebugString() const { return ToProto().ShortDebugString(); } string DebugString() const { return ToProto().DebugString(); } - public: + private: // The element type of this shape (tuple, array, etc). PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7f68c0ae1fb1e3de12b0243829f8f02cc1137894..235b065585cb6a3dd76774e358d05e0a78a2f084 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -112,6 +112,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, if (compare_layouts) { if (lhs.layout().format() != rhs.layout().format()) { + VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; return false; } if (LayoutUtil::IsDenseArray(lhs)) { @@ -145,7 +146,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, return false; } - for (int i = 0; i < ShapeUtil::Rank(lhs); ++i) { + for (int i = 0; i < lhs.rank(); ++i) { if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { VLOG(3) << "CompareShapes: lhs and rhs have different dynamic dimensions."; @@ -207,11 +208,6 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } -/* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(shape.IsArray()) << "Non-arrays do not have a rank, shape: " << shape; - return shape.dimensions_size(); -} - /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -343,7 +339,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { CHECK(LayoutUtil::IsDenseArray(*shape)); - shape->mutable_layout()->add_minor_to_major(Rank(*shape)); + shape->mutable_layout()->add_minor_to_major(shape->rank()); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); } @@ -358,7 +354,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (!IsArray(shape)) { + if (!shape.IsArray()) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -400,27 +396,24 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } -/* static */ bool ShapeUtil::IsArray(const Shape& shape) { - return IsArrayPrimitiveType(shape.element_type()); -} - /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { - return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), - shape.tuple_shapes().end(), IsTuple); + return shape.IsTuple() && + absl::c_any_of(shape.tuple_shapes(), + [](const Shape& s) { return s.IsTuple(); }); } /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) { - return IsTuple(shape) && TupleElementCount(shape) == 0; + return shape.IsTuple() && TupleElementCount(shape) == 0; } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { - CHECK(IsTuple(shape)) << HumanString(shape); + CHECK(shape.IsTuple()) << HumanString(shape); return shape.tuple_shapes_size(); } /* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape, int64 index) { - CHECK(IsTuple(shape)); + CHECK(shape.IsTuple()); CHECK_GT(TupleElementCount(shape), index); TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index))); return shape.tuple_shapes(index); @@ -436,7 +429,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); - CHECK(IsTuple(tuple)); + CHECK(tuple.IsTuple()); CHECK_LE(start, TupleElementCount(tuple)); CHECK_LE(limit, TupleElementCount(tuple)); @@ -453,15 +446,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( complex_shape.element_type())); } -/* static */ bool ShapeUtil::ShapeIs(const Shape& shape, - PrimitiveType element_type, - std::initializer_list dimensions) { - return Equal(shape, MakeShape(element_type, dimensions)); -} - /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), shape.rank()); if (shape.dimensions().size() == 1) { return shape.dimensions()[0]; } @@ -471,8 +458,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { - CHECK(IsArray(shape) || IsTuple(shape)); - if (IsArray(shape)) { + CHECK(shape.IsArray() || shape.IsTuple()); + if (shape.IsArray()) { return ElementsIn(shape); } int64 count = 0; @@ -505,7 +492,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -529,7 +516,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -545,7 +532,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); } result += "]"; - if (!IsScalar(shape) && IsArray(shape)) { + if (!IsScalar(shape) && shape.IsArray()) { if (LayoutUtil::HasLayout(shape)) { StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } @@ -580,8 +567,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameDimensions(lhs, rhs); + if (lhs.IsArray()) { + return rhs.IsArray() && SameDimensions(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), @@ -594,8 +581,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && + if (lhs.IsArray()) { + return rhs.IsArray() && SameElementTypeIgnoringFpPrecision(lhs, rhs) && CompatibleIgnoringElementType(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && @@ -615,7 +602,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, int64 dimension_number) { if (dimension_number < 0) { - dimension_number += Rank(shape); + dimension_number += shape.rank(); } CHECK_GE(dimension_number, 0); return dimension_number; @@ -669,7 +656,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(shape)); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); - } else if (IsArray(shape)) { + } else if (shape.IsArray()) { int64 byte_size = ByteSizeOfElements(shape); if (LayoutUtil::IsSparseArray(shape)) { byte_size += ByteSizeOfSparseIndices(shape); @@ -755,10 +742,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return Status::OK(); } - if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) { return InvalidArgument("sparse arrays must have rank > 0"); } - for (int64 i = 0; i < Rank(shape); ++i) { + for (int64 i = 0; i < shape.rank(); ++i) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( @@ -774,7 +761,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); - if (!IsArray(shape)) { + if (!shape.IsArray()) { return Status::OK(); } @@ -867,7 +854,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( ShapeIndexView index) { const Shape* subshape = &shape; for (auto i : index) { - if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size() || i < 0) { + if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) { return false; } subshape = &subshape->tuple_shapes(i); @@ -879,7 +866,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)) + CHECK(return_shape->IsTuple()) << "Invalid index " << index << " for shape " << shape; return_shape = &return_shape->tuple_shapes(i); } @@ -890,7 +877,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - if (!IsTuple(*return_shape) || i < 0 || + if (!return_shape->IsTuple() || i < 0 || i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", @@ -905,7 +892,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( ShapeIndexView index) { Shape* return_shape = shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)); + CHECK(return_shape->IsTuple()); return_shape = return_shape->mutable_tuple_shapes(i); } return return_shape; @@ -913,11 +900,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { - return !IsTuple(GetSubshape(shape, index)); + return !GetSubshape(shape, index).IsTuple(); } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { - if (!IsTuple(shape)) { + if (!shape.IsTuple()) { return 1; } int64 count = 0; @@ -1081,8 +1068,8 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { - CHECK(IsArray(shape_pre)); - CHECK(IsArray(shape_post)); + CHECK(shape_pre.IsArray()); + CHECK(shape_post.IsArray()); auto nil = std::make_tuple(false, std::vector(), std::vector()); @@ -1129,7 +1116,7 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, auto unmodified_dim_pair = i < unmodified_dims.size() ? unmodified_dims[i] - : std::make_pair(Rank(shape_pre), Rank(shape_post)); + : std::make_pair(shape_pre.rank(), shape_post.rank()); if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { return nil; } @@ -1141,8 +1128,8 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), @@ -1192,8 +1179,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); CHECK(LayoutUtil::HasLayout(input_shape)); CHECK(LayoutUtil::HasLayout(output_shape)); @@ -1321,12 +1308,12 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); - for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { + for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) { if (input_shape.dimensions(input_dim) <= 1) { continue; } - std::vector input_unit_index(Rank(input_shape), 0); + std::vector input_unit_index(input_shape.rank(), 0); input_unit_index[input_dim] = 1; int64 logical_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, @@ -1352,11 +1339,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ absl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); - int64 input_rank = Rank(input_shape); - int64 output_rank = Rank(output_shape); + int64 input_rank = input_shape.rank(); + int64 output_rank = output_shape.rank(); // First, calculate an alignment of the dimensions. A consecutive sequence of // input dimensions and output dimensions belong to the same alignment part if @@ -1493,14 +1480,14 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { - CHECK(IsArray(shape)); + CHECK(shape.IsArray()); shape.DeleteDimension(dim_to_delete); return shape; } /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { - CHECK(IsArray(shape)); + CHECK(shape.IsArray()); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index c8295e85ce15722a3da7ab5585e6d499d3b5efa2..e98c6e024bec1f6db5c40d3cd3215ca44eb13698 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -185,7 +185,7 @@ class ShapeUtil { // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: IsArray(shape) + // Precondition: shape.IsArray() static int64 ElementsIn(const Shape& shape); // As ElementsIn(), but recurses through tuples. @@ -296,11 +296,6 @@ class ShapeUtil { // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); - // Returns the rank (number of dimensions) of the given shape. - // Precondition: !IsTuple(shape) - ABSL_DEPRECATED("Use `Shape::rank` instead.") - static int64 Rank(const Shape& shape); - // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -314,10 +309,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return IsArray(shape) && Rank(shape) == 0; + return shape.IsArray() && shape.rank() == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return IsArray(shape) && TrueRank(shape) == 0; + return shape.IsArray() && TrueRank(shape) == 0; } // Returns whether "shape" is a scalar (array) with the given element_type. @@ -457,31 +452,6 @@ class ShapeUtil { // that floating point numbers are signed. static bool ElementIsSigned(const Shape& shape); - // Returns whether the shape is a tuple. - ABSL_DEPRECATED("Use Shape::IsTuple instead.") - static bool IsTuple(const Shape& shape) { - return shape.element_type() == TUPLE; - } - - // Returns whether the shape is an opaque value (i.e. an 'existential' typed - // value that is passed to CustomCall operations). - ABSL_DEPRECATED("Use Shape::IsOpaque instead.") - static bool IsOpaque(const Shape& shape) { - return shape.element_type() == OPAQUE; - } - - // Returns whether the shape is an token value used for ordering - // side-effecting operations. - ABSL_DEPRECATED("Use Shape::IsToken instead.") - static bool IsToken(const Shape& shape) { - return shape.element_type() == TOKEN; - } - - // Returns whether the shape is an array. Note that scalars are considered - // arrays. - ABSL_DEPRECATED("Use Shape::IsArray instead.") - static bool IsArray(const Shape& shape); - // Returns whether the given primitive type corresponds to an array shape. static bool IsArrayPrimitiveType(PrimitiveType primitive_type); @@ -511,12 +481,6 @@ class ShapeUtil { // shape. static Shape ComplexComponentShape(const Shape& complex_shape); - // Shorthand for testing whether a shape is of a given element type and - // sequence of dimensions. - ABSL_DEPRECATED("Use Equal() instead.") - static bool ShapeIs(const Shape& shape, PrimitiveType element_type, - std::initializer_list dimensions); - // Returns true if the given shape has a subshape at the given index. static bool IndexIsValid(const Shape& shape, ShapeIndexView index); @@ -764,7 +728,7 @@ class ShapeUtil { if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } - CHECK_EQ(Rank(shape), base.size()); + CHECK_EQ(shape.rank(), base.size()); CHECK_EQ(incr.size(), base.size()); CHECK_EQ(count.size(), base.size()); const int64 rank = LayoutUtil::MinorToMajor(shape).size(); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 8e7c20819335bba161185bcf1237f709cfdb2a5d..61b4e73e060c18a3d0108e68d1117607d6c11c0f 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -538,10 +538,6 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } -TEST(ShapeUtilTest, ShapeIs) { - EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {})); -} - TEST(ShapeUtilTest, ForEachIndex) { struct ShapeDimensionAndNumberInvocations { std::vector dimensions; diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index a96d483462efd77ae4761541e8c79b2c84fa49f3..0c25355467da3fd346d80db790d78252869975ef 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span values) { auto sort_order_less = [this](int64 lhs, int64 rhs) { return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; }; - std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + absl::c_sort(sort_order, sort_order_less); // Reorder the array elements according to sort_order. Work through the array // and follow cycles so we can do the reorder in-place. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index ee24d4d99cb1f7ce51a72c6258cbadd6adf12f81..0fd0fc108a6c5432cf8f3a74006949d5a46b2b99 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -71,6 +71,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", @@ -276,9 +277,6 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -344,9 +342,6 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -367,9 +362,6 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -387,9 +379,6 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -413,6 +402,10 @@ xla_test( xla_test( name = "xla_hlo_profile_test", srcs = ["xla_hlo_profile_test.cc"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", @@ -436,9 +429,6 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -453,7 +443,6 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -506,9 +495,6 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", @@ -524,9 +510,6 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -544,7 +527,6 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -562,7 +544,6 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -623,9 +604,6 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -648,7 +626,6 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -698,7 +675,6 @@ xla_test( xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -725,7 +701,6 @@ xla_test( srcs = ["dot_operation_test.cc"], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -736,6 +711,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -793,6 +769,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -806,9 +783,6 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -828,9 +802,6 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -951,6 +922,11 @@ xla_test( xla_test( name = "batch_normalization_test", srcs = ["batch_normalization_test.cc"], + blacklisted_backends = [ + # BatchNorm HLOs are not handled by the interpreter backend, and the + # BatchNorm expander is not run on the interpreter. + "interpreter", + ], shard_count = 40, deps = [ ":test_utils", @@ -1042,9 +1018,6 @@ xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1065,9 +1038,6 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1085,9 +1055,6 @@ xla_test( name = "dynamic_ops_test", timeout = "moderate", srcs = ["dynamic_ops_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1113,9 +1080,6 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1139,9 +1103,6 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1162,7 +1123,6 @@ xla_test( srcs = ["reduce_test.cc"], shard_count = 40, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1229,7 +1189,6 @@ xla_test( srcs = [], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], xla_test_library_deps = [":reduce_window_test_library"], @@ -1241,7 +1200,6 @@ xla_test( timeout = "long", srcs = ["select_and_scatter_test.cc"], tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1267,9 +1225,6 @@ xla_test( xla_test( name = "copy_test", srcs = ["copy_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", @@ -1290,9 +1245,6 @@ xla_test( xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1306,9 +1258,6 @@ xla_test( xla_test( name = "token_hlo_test", srcs = ["token_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", @@ -1323,9 +1272,6 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1368,9 +1314,6 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1388,9 +1331,6 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1410,9 +1350,6 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1434,9 +1371,6 @@ xla_test( xla_test( name = "fmax_test", srcs = ["fmax_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1451,9 +1385,6 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1468,9 +1399,6 @@ xla_test( xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1517,9 +1445,6 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1545,9 +1470,6 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1566,9 +1488,6 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", @@ -1592,9 +1511,6 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1615,9 +1531,6 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1637,6 +1550,10 @@ xla_test( xla_test( name = "all_reduce_test", srcs = ["all_reduce_test.cc"], + blacklisted_backends = [ + # All reduce is not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1661,9 +1578,6 @@ xla_test( xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1703,9 +1617,6 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1767,6 +1678,10 @@ xla_test( xla_test( name = "execution_profile_test", srcs = ["execution_profile_test.cc"], + blacklisted_backends = [ + # Execution profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1781,6 +1696,10 @@ xla_test( name = "execution_profile_test_with_xla_hlo_profile", srcs = ["execution_profile_test.cc"], args = ["--xla_hlo_profile"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1794,9 +1713,6 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -1819,9 +1735,6 @@ xla_test( xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1883,9 +1796,6 @@ xla_test( xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -2003,6 +1913,10 @@ xla_test( xla_test( name = "outfeed_in_nested_computation_test", srcs = ["outfeed_in_nested_computation_test.cc"], + blacklisted_backends = [ + # Outfeed ops are not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla/tests:local_client_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2179,7 +2093,6 @@ xla_test( srcs = ["iota_test.cc"], shard_count = 30, tags = [ - "enable_for_xla_interpreter", # Require optimized builds, iota_test_cpu is very slow in fastbuild. "optonly", ], @@ -2207,3 +2120,18 @@ tf_cc_test( "@com_google_absl//absl/synchronization", ], ) + +xla_test( + name = "ptxas_bug_120501638", + srcs = ["ptxas_bug_120501638.cc"], + tags = [ + # Disabled in OSS until nvidia publicly releases a fixed ptxas. + "no_oss", + ], + deps = [ + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 915b456b52215f8d6a9eb6c5b933f3502f1d3d2c..e4cc8b41991927dab815fac2153e525205aad4c8 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2047,6 +2047,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto minimum = ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN}); + auto argument = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = ConstantR1(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f}); + Clamp(minimum, argument, maximum); + + ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XlaBuilder builder(TestName()); auto minimum = ConstantR0(&builder, 0.0f); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index e9728e636f0ee032416b2da17a3ea83c5bb18083..63e48117056dec4af603cbc85e478fcb15ad0cec 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -76,7 +76,9 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) { error_spec_); } -XLA_TEST_F(Bfloat16Test, BatchNormTraining) { +// Disabled on interpreter since BatchNormExanper is not run by default on the +// intepreter backend. +XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); @@ -110,7 +112,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } -XLA_TEST_F(Bfloat16Test, BatchNormGrad) { +// Disabled on interpreter since BatchNormExanper is not run by default on the +// intepreter backend. +XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index d5b3a4d14f932d76b2e7b198ee233756c94a2694..247328b730f3af936d933f824da491b593b27c90 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -109,7 +109,10 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { +// Disabled for interpreter since ExecuteAsyncOnStream is not implemented on +// interpreter backend. +XLA_TEST_F(ClientTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 249693891290e14645ee5b4b4d97b2d506a01302..9db9f2563b636c4f929585eb13a9c7f809833eda 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -467,8 +467,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { // servers. The error message is missing the operator ++. template void iota_int_init_value(std::vector& values, int init_value) { - std::for_each(values.begin(), values.end(), - [&](T& value) { value = static_cast(init_value++); }); + absl::c_for_each(values, + [&](T& value) { value = static_cast(init_value++); }); } template diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index c5d8b663f4abe77e05ec213d2e4e075c260a8655..2e02968ac5f05c60ed9488d848f6c7fc387b41b4 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -1147,5 +1148,38 @@ XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +class DotOperationTextTest : public HloTestBase {}; + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index dcb469087e0064d17ce3b04fdeaf0b6136069a55..1b0bebe2d03a9a153cd0c80329ed0c49c91333a3 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -48,7 +48,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { tensorflow::SubProcess file_check_process; file_check_process.SetProgram(file_check_path, - {file_check_path, pattern_path}); + {file_check_path, "-v", pattern_path}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, tensorflow::ACTION_PIPE); file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, @@ -71,9 +71,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { LOG(WARNING) << "NOTE: FileCheck binary does not exist!"; } - LOG(WARNING) << "FileCheck error: " << standard_error; - LOG(WARNING) << "FileCheck input was:"; - XLA_LOG_LINES(tensorflow::WARNING, input); + LOG(WARNING) << "FileCheck error:\n" << standard_error; LOG(WARNING) << "FileCheck pattern was:"; XLA_LOG_LINES(tensorflow::WARNING, pattern); } else if (!standard_error.empty()) { diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index daa89398a697af9149797d621c3bdca80a00aedd..d65b67a535d43553a3a94f76482ad4618f9b8aab 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -600,7 +600,9 @@ ENTRY main { class GatherClientLibraryTest : public ClientLibraryTestBase {}; -XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { +// Disabled on interpreter since ExectuteAsyncOnStream is not supported. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) { // We create this HLO, but using the XlaBuilder API. // // ENTRY main { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d57846e19bb80c5b9c87d50596da2915f9aef317..66f72ba8d20b8ef1f436da4425b2bb6518ee9a94 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -139,7 +139,8 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( const string& name) { return absl::make_unique( name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); } StatusOr> @@ -147,7 +148,8 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { auto module = absl::make_unique( TestName(), config, verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); TF_RETURN_IF_ERROR(module->Verify()); return std::move(module); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 1d1e7f437296a7493ef7da07039fcf6d273f35bc..69a4f96288c7285010e9adbdc33f1b394f58d8d2 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -46,10 +46,12 @@ class VerifiedHloModule : public HloModule { public: VerifiedHloModule(const string& name, const HloModuleConfig& config, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function shape_size_function) : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + verifier_( + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, + /*instruction_can_change_layout_func=*/{}, shape_size_function) {} ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 6522c563252a0e8271bf733668f39bcf00a80d06..96527886b718bc1ea4ce8cc2d7dbeb2e3ef1d1eb 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -842,7 +842,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { LiteralUtil::CreateR0(123456789000LL)})); } -XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { +// Disabled on interpreter backend since infeed HLO is unsupported. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); @@ -867,7 +868,8 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } -XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { +// Disabled on interpreter backend since infeed/outfeed HLOs are unsupported. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 8f2c26f0eea9c7a3b33cd77e5977924c1659535a..e49bcf26bd6e50f8fb36c86f217907b5d4901eae 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -80,7 +80,9 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER( + DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) { for (int64 seed = 0; seed < 100; ++seed) { // The largest negative number smaller than zero in bf16 that's not // denormalized. @@ -103,7 +105,9 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_CPU(ScalarBF16CountTests)))) { // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, // they should get similar counts. bfloat16 low = static_cast(32.25); @@ -276,6 +280,39 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } +// This test verifies that the two RNG instructions with the same parameters in +// the same HloComputation produces different values. +XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) { + // Build a U[0,1) computation. + auto build_computation = [this]() { + XlaBuilder builder(TestName()); + auto a = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + auto b = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + Tuple(&builder, {a, b}); + return builder.Build(); + }; + + ExecutionOptions execution_options = execution_options_; + execution_options.set_seed(42); + + Literal result_tuple; + { + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result_tuple, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options)); + } + + auto results = result_tuple.DecomposeTuple(); + ASSERT_EQ(results.size(), 2); + + EXPECT_FALSE(LiteralTestUtil::Equal(results[0], results[1])); +} + XLA_TEST_F(PrngTest, TenValuesN01) { XlaBuilder builder(TestName()); RngNormal(ConstantR0(&builder, 0), ConstantR0(&builder, 1), diff --git a/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e5d7db97e88936e7336ed02a5c7a1171254b0cf --- /dev/null +++ b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class PtxasBugTest : public HloTestBase {}; + +// Checks for a bug in ptxas, tracked as Google bug 120501638, and nvidia bug +// 2459377. We never received an explanation of what exactly was going wrong +// here in ptxas. Known-bad in ptxas 10.0.145, known-good in ptxas 10.0.249. +TEST_F(PtxasBugTest, DoIt) { + const char* const kModuleStr = R"( +HloModule test + +add_F32.14 { + lhs.15 = f32[] parameter(0) + rhs.16 = f32[] parameter(1) + ROOT add.17 = f32[] add(lhs.15, rhs.16) +} + +ENTRY testcase { + arg0.1 = f32[2,5,2]{2,1,0} parameter(0) + reshape.2 = f32[2,5,2]{2,1,0} reshape(arg0.1) + constant.3 = f32[] constant(0) + pad.4 = f32[2,6,2]{2,1,0} pad(reshape.2, constant.3), padding=0_0x0_1x0_0 + reshape.5 = f32[2,3,2,2]{3,2,1,0} reshape(pad.4) + transpose.6 = f32[2,2,3,2]{3,0,2,1} transpose(reshape.5), dimensions={2,0,1,3} + reshape.7 = f32[4,3,2]{2,1,0} reshape(transpose.6) + reshape.8 = f32[4,1,3,2]{3,2,1,0} reshape(reshape.7) + transpose.9 = f32[4,2,1,3]{1,3,2,0} transpose(reshape.8), dimensions={0,3,1,2} + convert.10 = f32[4,2,1,3]{1,3,2,0} convert(transpose.9) + constant.12 = f32[] constant(0) + pad.13 = f32[4,2,1,3]{3,2,1,0} pad(convert.10, constant.12), padding=0_0x0_0x0_0x0_0 + constant.11 = f32[] constant(0) + reduce-window.18 = f32[4,2,1,3]{3,2,1,0} reduce-window(pad.13, constant.11), + window={size=1x1x1x1}, to_apply=add_F32.14 + constant.19 = f32[] constant(1) + broadcast.20 = f32[4,2,1,3]{3,2,1,0} broadcast(constant.19), dimensions={} + divide.21 = f32[4,2,1,3]{3,2,1,0} divide(reduce-window.18, broadcast.20) + convert.22 = f32[4,2,1,3]{3,2,1,0} convert(divide.21) + transpose.23 = f32[4,1,3,2]{2,1,3,0} transpose(convert.22), dimensions={0,2,3,1} + reshape.24 = f32[4,3,2]{2,1,0} reshape(transpose.23) + reshape.25 = f32[2,2,3,2]{3,2,1,0} reshape(reshape.24) + transpose.26 = f32[2,3,2,2]{3,1,0,2} transpose(reshape.25), dimensions={1,2,0,3} + reshape.27 = f32[2,6,2]{2,1,0} reshape(transpose.26) + slice.28 = f32[2,5,2]{2,1,0} slice(reshape.27), slice={[0:2], [0:5], [0:2]} + reshape.29 = f32[2,5,2]{2,1,0} reshape(slice.28) + tuple.30 = (f32[2,5,2]{2,1,0}) tuple(reshape.29) + ROOT get-tuple-element.31 = f32[2,5,2]{2,1,0} get-tuple-element(tuple.30), index=0 +})"; + + // Create a module with the true-default flags, not the default-for-testing + // flags. In particular, true-default flags enable unrolling, whereas for + // testing we disable unrolling, and this bug doesn't trigger without + // unrolling. + HloModuleConfig config; + config.set_debug_options(DefaultDebugOptionsIgnoringFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01})); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 7ca99a91635e85cd0888e59ecde31e47fec21844..80a6868485c9162d1cb0de24f0adf3f1c1d2503a 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -79,30 +79,28 @@ string PrependDisabledIfIndicated(const string& test_case_name, // heuristic to decide whether the test case should be disabled, and we // determine whether the test case should be disabled by resolving the (test // case name, test name) in a manifest file. -#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class, parent_id) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public parent_class { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - \ - private: \ - virtual void TestBody(); \ - static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - \ - ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::test_info_ = \ - ::testing::internal::MakeAndRegisterTestInfo( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - nullptr, nullptr, \ - ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ - parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ - new ::testing::internal::TestFactoryImpl); \ +#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \ + class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + : public parent_class { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ + \ + private: \ + virtual void TestBody(); \ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)); \ + }; \ + \ + ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)::test_info_ = \ + ::testing::RegisterTest( \ + #test_case_name, \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ + nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \ + return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \ + }); \ void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() // This is identical to the TEST_F macro from "gtest", but it potentially @@ -111,9 +109,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, // Per usual, you can see what tests are available via --gunit_list_tests and // choose to run tests that have been disabled via the manifest via // --gunit_also_run_disabled_tests. -#define XLA_TEST_F(test_fixture, test_name) \ - XLA_GTEST_TEST_(test_fixture, test_name, test_fixture, \ - ::testing::internal::GetTypeId()) +#define XLA_TEST_F(test_fixture, test_name) \ + XLA_GTEST_TEST_(test_fixture, test_name, test_fixture) // Likewise, this is identical to the TEST_P macro from "gtest", but // potentially disables the test based on the DISABLED_MANIFEST file. diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 290886624de9f9fcca533d01f15d4a3e4c23d7ee..95c89b0ba6f29c453abab88e29bca13ee006455a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -274,16 +275,9 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -Literal MakeRandomIndex(absl::Span index_space, - std::minstd_rand0* engine) { - std::vector start_indices(index_space.size()); - if (engine != nullptr) { - for (int i = 0; i < index_space.size(); ++i) { - std::uniform_int_distribution generator(0, index_space[i]); - start_indices[i] = generator(*engine); - } - } - return LiteralUtil::CreateR1(start_indices); +Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) { + std::uniform_int_distribution generator(0, index_bound); + return LiteralUtil::CreateR0(generator(*engine)); } // Use dataflow analysis on each parameter to see if there are uses that would @@ -300,8 +294,8 @@ std::vector FindConstrainedUses( HloInstruction* instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || - (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) { constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kFusion) { const HloInstruction* const to_analyze = @@ -336,7 +330,7 @@ std::vector FindConstrainedUses( StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - std::vector index_space; + int64 index_bound = INT64_MAX; bool no_duplicates = false; bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; @@ -348,19 +342,16 @@ StatusOr CreateLiteralForConstrainedUses( const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice ? use->shape() : use->operand(1)->shape(); - const int64 rank = indexed_shape.rank(); - if (!index_space.empty()) { - TF_RET_CHECK(rank == index_space.size()); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = std::min( - index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i)); - } - } else { - index_space.resize(rank); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); + const int64 first_index = + Cast(use)->first_index_operand_number(); + for (int64 operand = first_index; operand < use->operand_count(); + ++operand) { + if (use->operand(operand) == ¶m) { + index_bound = std::min( + index_bound, + ShapeUtil::GetDimension(indexed_shape, operand - first_index) - + ShapeUtil::GetDimension(slice_shape, + operand - first_index)); } } break; @@ -388,13 +379,14 @@ StatusOr CreateLiteralForConstrainedUses( } int constraint_count = 0; constraint_count += no_duplicates ? 1 : 0; - constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += (index_bound != INT64_MAX) ? 1 : 0; constraint_count += needs_constant ? 1 : 0; if (constraint_count > 1) { return Unimplemented("Conflicting operand generation constraints."); } - if (!index_space.empty()) { - return MakeRandomIndex(index_space, engine); + if (index_bound != INT64_MAX) { + return MakeRandomIndex(index_bound, engine) + .Reshape(param.shape().dimensions()); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 448a66cfdd897b17cce1c87c050520a2f2eb0ea2..591d6c19228a313f530cdae18f4be37e7b517601 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -79,25 +79,26 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} - ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 3); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 5); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_EQ(args[0].Get({}), 0); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), 0); + EXPECT_LE(args[0].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), 0); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { @@ -105,28 +106,29 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - update_param.1 = f32[1,2,3]{0,1,2} parameter(3) - update_param.2 = f32[3,2,2]{0,1,2} parameter(4) - - dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) - ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + update_param.1 = f32[1,2,3]{0,1,2} parameter(5) + update_param.2 = f32[3,2,2]{0,1,2} parameter(6) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param.0, index_param.1, index_param.2) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param.0, index_param.1, index_param.2) })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 5); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 7); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_EQ(args[0].Get({}), 0); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), 0); + EXPECT_LE(args[0].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), 0); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { @@ -198,5 +200,33 @@ ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14 } } +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) { + auto module = ParseHloString(R"( +HloModule Test + +ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { + %parameter.1 = f32[20,20]{1,0} parameter(1) + %constant.1 = s32[1]{0} constant({0}) + %parameter.0 = s32[] parameter(0) + %bitcast.3 = s32[1]{0} bitcast(s32[] %parameter.0) + %concatenate.1 = s32[2]{0} concatenate(s32[1]{0} %constant.1, s32[1]{0} %bitcast.3), dimensions={0} + %dynamic-slice.2 = f32[20,1]{1,0} dynamic-slice(f32[20,20]{1,0} %parameter.1, s32[2]{0} %concatenate.1), dynamic_slice_sizes={20,1} + %bitcast.4 = f32[20]{0} bitcast(f32[20,1]{1,0} %dynamic-slice.2) + %dynamic-slice.3 = f32[1]{0} dynamic-slice(f32[20]{0} %bitcast.4, s32[1]{0} %bitcast.3), dynamic_slice_sizes={1} + ROOT %bitcast.5 = f32[] bitcast(f32[1]{0} %dynamic-slice.3) +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + EXPECT_TRUE(ShapeUtil::Equal(args[0].shape(), ShapeUtil::MakeShape(S32, {}))) + << ShapeUtil::HumanString(args[0].shape()); + EXPECT_TRUE( + ShapeUtil::Equal(args[1].shape(), ShapeUtil::MakeShape(F32, {20, 20}))) + << ShapeUtil::HumanString(args[1].shape()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 9c586bdeb05afb7378e92caed1f3edc408e051bf..426d6c84ee3a9e30bdba1da7ae570ed0279a8440 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -176,8 +176,9 @@ XLA_TEST_F(TupleTest, AddTupleElements) { {2.f, 4.f, 6.f}, // row 0 {5.f, 7.f, 9.f}, // row 1 }); - ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3})); - ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3}))); + ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, + ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3}))); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 1538f2afbafad53fa2f7a3fd1a4705e5ab55f536..c7337e8caae8f2ee25f4b25dc22439e08d2ecc25 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -174,9 +174,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, exec_run_options.set_allocator(backend->memory_allocator()); exec_run_options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); - ServiceExecutableRunOptions run_options( - exec_run_options, /*borrow_stream=*/nullptr, - backend->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions run_options(exec_run_options, + /*borrow_stream=*/nullptr); std::vector args = {&lhs_arg, &rhs_arg}; TF_ASSERT_OK_AND_ASSIGN( auto execution_result, diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 68cab7387cf1576072f96878b50f07def6862d8b..34b73b5206fa20d6dff7567afd78fd89897c8c33 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -86,7 +86,7 @@ bool IsPermutation(absl::Span permutation, int64 rank) { CHECK_LT(index, rank); output[index] = 0; } - return std::find(output.begin(), output.end(), -1) == output.end(); + return !absl::c_linear_search(output, -1); } std::vector InversePermutation( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 6722641e9d2c177440361e6f0d1f6c0804eb7cda..f2fd17dc99455a921bf875aad2a3661b4d456823 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span permutation); template int64 PositionInContainer(const Container& container, int64 value) { - return std::distance(container.begin(), - std::find(container.begin(), container.end(), value)); + return std::distance(container.begin(), absl::c_find(container, value)); } // Formats the container as a comma-separated string. StrAppend must support diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 51c73b3d17e4c32d9a8a14d3055ab56f02922af3..e001cc35f9fcea2783b3952e825838af6bbece72 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -137,25 +138,23 @@ bool HasPadding(const Window& window) { } bool HasSymmetricPadding(const Window& window) { - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() == dim.padding_high(); - }); + return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() == dim.padding_high(); + }); } bool HasSymmetricPadding(const PaddingConfig& padding_config) { - return std::all_of(padding_config.dimensions().begin(), - padding_config.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.edge_padding_low() == dim.edge_padding_high(); - }); + return absl::c_all_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.edge_padding_low() == + dim.edge_padding_high(); + }); } bool HasNegativePadding(const Window& window) { - return std::any_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() < 0 || dim.padding_high() < 0; - }); + return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() < 0 || dim.padding_high() < 0; + }); } bool HasBaseDilation(const Window& window) { @@ -190,10 +189,9 @@ bool AllOrNoneReversed(const Window& window) { return true; } bool reversed = window.dimensions()[0].window_reversal(); - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [&](const WindowDimension& dim) { - return dim.window_reversal() == reversed; - }); + return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); } bool HasDilation(const Window& window) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 0e8fa73f8170addfa5061b33f3d6882a13890bce..e2d7b6ef4666c533951960fd3dcf6869ec2b52c5 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -230,7 +230,11 @@ message DebugOptions { // Enable fast math with eigen in the HLO evaluator. bool xla_hlo_evaluator_use_fast_path = 106; - // Next id: 107 + // Temporary option to allow support for both the R1 and the scalar index + // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. + bool xla_allow_scalar_index_dynamic_ops = 107; + + // Next id: 108 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 67f475846e5f16060c1080759b0acb4216c4e72b..78d093ec4a5db49002bd987aff1c6d5ca2a3a0c6 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -62,7 +62,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/stream_executor:stream_executor_headers_lib", + "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 2ccdf0f02d840600d5e0649c4805e3672d4a1286..2ee1a6cd1aebcdbd65892b33e5044489070ab5c4 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -215,11 +215,6 @@ XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default; void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; - const Tensor& key_tensor = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(key_tensor.shape()), - errors::Internal("computation key should be a string scalar")); - int64 uid = key_tensor.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); @@ -230,9 +225,13 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { kXRTCompilationCacheResourceName, &cache)); core::ScopedUnref cache_unref(cache); - OP_REQUIRES_OK(ctx, cache->Release(uid)); - - VLOG(2) << "Released computation handle " << uid; + const Tensor& keys_tensor = ctx->input(0); + auto flat_keys = keys_tensor.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, cache->Release(key)); + VLOG(2) << "Released computation handle " << key; + } } } // namespace diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 2e2f3ff116a7b331df8dbd58a9fe40096f524140..c8def16bbc05bbc41ef2faf98a01c1d5d758890f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -453,17 +453,17 @@ class XRTReleaseAllocationOp : public OpKernel { void Compute(OpKernelContext* ctx) override { VLOG(1) << "XRTReleaseAllocationOp::Compute"; - const Tensor& allocation_handle = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_handle.shape()), - errors::Internal("handle input should be an int64 scalar")); - int64 key = allocation_handle.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(rm, key)); - - VLOG(2) << "Released allocation handle " << key; + const Tensor& allocation_handle = ctx->input(0); + auto flat_keys = allocation_handle.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, + XRTTupleAllocation::DeleteFromResourceManager(rm, key)); + VLOG(2) << "Released allocation handle " << key; + } } }; diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc index 7b3b50c69559f6003a108fdf6a1325dbdbaa80a6..9dd964e5467cd855d67764a512e95a6a18f482e1 100644 --- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc @@ -44,10 +44,10 @@ REGISTER_OP("XRTReleaseCompilationHandle") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards a computation from the compilation cache. The handle cannot be -subsequently used. +Discards one or more computation handles from the compilation cache. +The handle(s) cannot be subsequently used. -'handle' is an id returned from a XRTCompile Op. +'handle' is an ID (or vector of IDs) returned from a XRTCompile Op. )"); } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index fe6bee0dacf5dc2050613fc9ad34d3235b5a7b63..db58f0797df67774f4b74f77beef9a595342fe3f 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -127,10 +127,11 @@ REGISTER_OP("XRTReleaseAllocationHandle") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards an allocation from device memory. The handle cannot be subsequently +Discards one or more device memory handles. The handle(s) cannot be subsequently used. -'handle' is the id returned from the Op that produced the on-device allocation. +'handle' is the ID (or a vector of IDs) returned from the Op that produced the +on-device allocation. )"); REGISTER_OP("XRTReleaseAllAllocations") diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 5f8121703e108f26b048feb7a0412a282f52892c..be0c4b9392258ac3bb0066d0e704ad403cdc211b 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -258,8 +258,102 @@ TEST(RawApiTest, AllocAndRewrite) { EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); - auto release = - ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + Tensor release_tensor(DT_INT64, TensorShape({1})); + release_tensor.flat()(0) = allocation_handle; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, AllocReleaseMany) { + xrt::XLAAllocation alloc1; + *alloc1.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + xrt::XLAAllocation alloc2; + *alloc2.mutable_value() = + xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value1 = + ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString()); + auto value2 = + ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString()); + auto handle1 = ops::XRTAllocate(root, value1); + auto handle2 = ops::XRTAllocate(root, value2); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle1 = outputs[0].scalar()(); + int64 allocation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + release_tensor.flat()(0) = allocation_handle1; + release_tensor.flat()(1) = allocation_handle2; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, CompileAndReleaseMany) { + xrt::XLAComputation c1; + auto config1 = c1.mutable_config(); + auto shapes1 = config1->mutable_program_shape(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot()); + + xrt::XLAComputation c2; + auto config2 = c2.mutable_config(); + auto shapes2 = config2->mutable_program_shape(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); + StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(false); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation1 = + ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString()); + auto c_handle1 = ops::XRTCompile(root, computation1); + auto computation2 = + ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString()); + auto c_handle2 = ops::XRTCompile(root, computation2); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 compilation_handle1 = outputs[0].scalar()(); + int64 compilation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + release_tensor.flat()(0) = compilation_handle1; + release_tensor.flat()(1) = compilation_handle2; + + auto release = ops::XRTReleaseCompilationHandle(root, release_tensor); + outputs.clear(); TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, &outputs)); } @@ -862,6 +956,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XRTExecutionConfig e; e.set_release_input_handles(true); e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); auto e_config = diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc index d1405eae468492748ae88d842334a922dce272c6..8bf0f28d2233d9e7593365bc42187e327a1c4ac4 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc @@ -273,6 +273,8 @@ Status XRTCompilationCache::Lookup( return Status::OK(); } -string XRTCompilationCache::DebugString() { return "XRTCompilationCache"; } +string XRTCompilationCache::DebugString() const { + return "XRTCompilationCache"; +} } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index c43d0fc47873abdc82ee937c155bebc346a05f17..7398e847d8b744f947adb03e1bcfd5c0a5b2cc55 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -118,7 +118,7 @@ class XRTCompilationCache : public ResourceBase { // EntryRef holding the program is returned in entry. Status Lookup(int64 uid, std::unique_ptr* entry); - string DebugString() override; + string DebugString() const override; private: // An entry in the compilation cache. The entry is deleted once it has been diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 3e3d5024124e13b87eed6f79596d50cd64325914..4aac37737eaf07dc622fdbfbfbc0775b3e2dafee 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -172,7 +172,7 @@ class XRTTupleAllocation : public ResourceBase { // ownership of the device memory is transferred to the result. xla::ShapeTree ToDeviceMemoryTree(bool release); - string DebugString() override { return "XLA allocation handle"; } + string DebugString() const override { return "XLA allocation handle"; } private: // Creates a new handle with (tuple) shape. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 832db0f4ab46911e067d17b4a125706c276cf798..307bb6eca35817c666038010192c03c4327d65c9 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -63,7 +63,6 @@ py_library( "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 4f1a2a5693235183c8f486817b82c8c81fa389ec..af59686120593283e5dadd5c87c04d2ad24ff8a0 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -91,7 +91,6 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.recurrent.python import recurrent_api as recurrent diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 44532cb078f9bd1578172f8a7d8a4b55cd21a7cb..831c613f2c8c9a4fcc2cb9d313077fe79ee96fd7 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -186,8 +186,8 @@ "\n", " def __init__(self):\n", " super(RnnColorbot, self).__init__()\n", - " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", - " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256, dtype=tf.float32)\n", + " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128, dtype=tf.float32)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", @@ -241,7 +241,7 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = (length - 1, range(batch_size))\n", + " indices = (length - 1, list(range(batch_size)))\n", " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index 4652021fecabfa11fa6a8754dc884d89e151b590..e3b4535bac4a01a1277290e0d1ea6d3c7613731c 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -42,7 +42,7 @@ class BigtableClientResource : public ResourceBase { return client_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("BigtableClientResource(project_id: ", project_id_, ", instance_id: ", instance_id_, ")"); } @@ -67,7 +67,7 @@ class BigtableTableResource : public ResourceBase { ::google::cloud::bigtable::noex::Table& table() { return table_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat( "BigtableTableResource(client: ", client_->DebugString(), ", table: ", table_name_, ")"); diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index 3fe71a2ea730cc9b60b2e2088a0d80a08b38d1a9..c04d5578eef0aadfc918c2de734723fe05c5cfee 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" -#include "google/bigtable/v2/data.pb.h" +#include "external/com_github_googleapis_googleapis/google/bigtable/v2/data.pb.h" #include "google/protobuf/wrappers.pb.h" #include "re2/re2.h" #include "tensorflow/core/lib/strings/stringprintf.h" diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index ee052ac60387d8f993e4942dd7dff39e191dd3a4..47d910d42a27db4b857eeb12209dfbb429dd1be2 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -487,8 +487,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper_0 <= 0.98) self.assertTrue(frac_below_upper_1 >= 0.92) self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.91) + self.assertTrue(frac_both_below_upper <= 0.99) train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( two_dimension=True) @@ -516,8 +516,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_above_lower_0 <= 0.98) self.assertTrue(frac_above_lower_1 >= 0.92) self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.91) + self.assertTrue(frac_both_above_lower <= 0.99) class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -806,8 +806,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper_0 <= 0.98) self.assertTrue(frac_below_upper_1 >= 0.92) self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.91) + self.assertTrue(frac_both_below_upper <= 0.99) train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( two_dimension=True) @@ -835,8 +835,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): self.assertTrue(frac_above_lower_0 <= 0.98) self.assertTrue(frac_above_lower_1 >= 0.92) self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.91) + self.assertTrue(frac_both_above_lower <= 0.99) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index e446c411a8d5075563b8f8b912b29df310e16c8c..6faf6963011b698a3b233329d87471da7608e44a 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -96,7 +96,7 @@ class StatsAccumulatorResource : public boosted_trees::StampedResource { TensorShapeUtils::IsScalar(hessian_shape)); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("StatsAccumulatorResource[size=", values_.size(), "]"); } diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index fca22c71a83459cb290eaebcf107cf1c14c222b7..c3685b54e201f73039f6623443c67ba2b217a51e 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -62,8 +62,8 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): saver.BaseSaverBuilder.SaveSpec(ensemble_config, slice_spec, name + "_config"), ] - super(TreeEnsembleVariableSavable, - self).__init__(tree_ensemble_handle, specs, name) + super(TreeEnsembleVariableSavable, self).__init__(tree_ensemble_handle, + specs, name) self._tree_ensemble_handle = tree_ensemble_handle self._create_op = create_op @@ -115,7 +115,7 @@ class TreeEnsembleVariable(tracking.TrackableResource): def _gather_saveables_for_checkpoint(self): return { - "tree_ensemble_variable": + self.resource_handle.op.name + "/tree_ensemble_variable": functools.partial( TreeEnsembleVariableSavable, tree_ensemble_handle=self.resource_handle, @@ -131,8 +131,8 @@ def tree_ensemble_variable(stamp_token, Args: stamp_token: The initial stamp token value for the ensemble resource. - tree_ensemble_config: A `Tensor` of type `string`. - Serialized proto of the tree ensemble. + tree_ensemble_config: A `Tensor` of type `string`. Serialized proto of the + tree ensemble. name: A name for the ensemble variable. container: An optional `string`. Defaults to `""`. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index a5951fb7377d48748f5eb578c034176517df7749..e78ec476ab3b43e5eb56a2502008bb8020ae97e0 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -566,9 +566,10 @@ class GradientBoostedDecisionTreeModel(object): # Determine if ensemble is colocated with the inputs. if self._ensemble_handle.device != input_deps[0].device: # Create a local ensemble and get its local stamp. - with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name: + with ops.name_scope("local_ensemble", "TreeEnsembleVariable"): local_ensemble_handle = ( - gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name)) + gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._ensemble_handle.op.name + "/local_ensemble")) create_op = gen_model_ops.create_tree_ensemble_variable( local_ensemble_handle, stamp_token=-1, tree_ensemble_config="") with ops.control_dependencies([create_op]): diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 94aeb2c7bb48c6eddb6c7894f8bf6f1567470113..0fe57c0a4e8375cc7ec7aca9553bded87e238b33 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -34,7 +34,7 @@ class DecisionTreeEnsembleResource : public StampedResource { protobuf::Arena::CreateMessage< boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {} - string DebugString() override { + string DebugString() const override { return strings::StrCat("GTFlowDecisionTreeEnsemble[size=", decision_tree_ensemble_->trees_size(), "]"); } diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h index fdaaae7f472c8f564ab45a8366d3746cbf1158ee..574e3065e7f46049815897ef73e44d33f0d23f0f 100644 --- a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h @@ -43,7 +43,7 @@ class QuantileStreamResource : public StampedResource { set_stamp(stamp_token); } - string DebugString() override { return "QuantileStreamResource"; } + string DebugString() const override { return "QuantileStreamResource"; } tensorflow::mutex* mutex() { return &mu_; } diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index ada41687261ab63286933d01da4e286173042e0c..4e529322c7c76797938468b405cd175609dc0a73 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -2,7 +2,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "checkpoint", @@ -27,17 +27,17 @@ py_library( ], ) -py_test( +tf_py_test( name = "containers_test", srcs = ["containers_test.py"], - deps = [ + additional_deps = [ ":containers", + "@six_archive//:six", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/training/checkpointable:base", "//tensorflow/python/training/checkpointable:util", - "@six_archive//:six", ], ) @@ -53,18 +53,18 @@ py_library( ], ) -py_test( +tf_py_test( name = "python_state_test", srcs = ["python_state_test.py"], - deps = [ + additional_deps = [ ":python_state", + "//third_party/py/numpy", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/eager:test", "//tensorflow/python/training/checkpointable:util", - "//third_party/py/numpy", ], ) @@ -80,10 +80,10 @@ py_library( ], ) -py_test( +tf_py_test( name = "split_dependency_test", srcs = ["split_dependency_test.py"], - deps = [ + additional_deps = [ ":split_dependency", "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", @@ -106,10 +106,10 @@ py_library( ], ) -py_test( +tf_py_test( name = "visualize_test", srcs = ["visualize_test.py"], - deps = [ + additional_deps = [ ":visualize", "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 1311063ec023bdaa2588d6f1c826bf900f7dea09..20f8c2b2453a58fdbe5a3587fa6687debd9c06d3 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -27,7 +27,6 @@ tf_kernel_library( deps = [ ":bigquery_table_accessor", ":bigquery_table_partition_proto_cc", - "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:reader_base", @@ -79,7 +78,6 @@ tf_kernel_library( srcs = ["gcs_config_ops.cc"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform/cloud:curl_http_request", diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 46a193971c5084523d432065f265fa7a9909f595..6c6a5df7f76723800740a81ccdcb137a0ec33846 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -31,17 +31,17 @@ if (systemlib_ABSEIL_CPP) message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") - add_custom_target(abseil_cpp) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + add_custom_target(abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) else (systemlib_ABSEIL_CPP) include (ExternalProject) - set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp) - set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) - set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) - set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp-build) + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_URL https://github.com/abseil/abseil-cpp.git) + set(abseil_cpp_TAG master) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -49,8 +49,11 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/Release/absl_time.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) else() set(abseil_cpp_STATIC_LIBRARIES @@ -62,6 +65,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/absl_time.lib ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) endif() else() @@ -74,15 +78,18 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a + ${abseil_cpp_BUILD}/absl/time/libabsl_time.a ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() - ExternalProject_Add(abseil_cpp + ExternalProject_Add(abseil_cpp_build PREFIX abseil_cpp - URL ${abseil_cpp_URL} - URL_HASH ${abseil_cpp_HASH} + GIT_REPOSITORY ${abseil_cpp_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release + COMMAND ${CMAKE_COMMAND} --build . --config Release INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} @@ -91,8 +98,10 @@ else (systemlib_ABSEIL_CPP) ) include_directories(${abseil_cpp_INCLUDE_DIR}) + message(STATUS ${abseil_cpp_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 479609458c64f7c7bd7b3ce6b23aceaa3db17f21..b15143bfc1cd787b156c9d6dd724a17730f0f8fb 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 1.20.1) +set(nsync_TAG 1.20.2) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 96160568fa79291a7b391761373e1eaf0f70974e..21ae9a08a6bb8f71e5935ddde2d7bb3ed0cd8bbc 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -1,6 +1,9 @@ # python_sanity_test.py will complain about invalid or missing entries # problematic entries can be commented for temporary whitelisting tensorflow +tensorflow/compiler +tensorflow/compiler/xla +tensorflow/compiler/xla/service tensorflow/core tensorflow/core/example tensorflow/core/framework diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index d7b2a1339e047aba0a9424a53a63726805e89721..9ae5831f471414d3b610870f9e5f08265f7cb22a 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -302,8 +302,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h" + "${tensorflow_source_dir}/tensorflow/core/summary/*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*.h" "${tensorflow_source_dir}/public/*.h" ) @@ -317,14 +317,14 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.h" "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/loader.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/vacuum.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*test*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/loader.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/vacuum.cc" ) # TODO(jart): Why doesn't this work? # set_source_files_properties( -# ${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/snapfn.cc +# ${tensorflow_source_dir}/tensorflow/core/lib/db/snapfn.cc # PROPERTIES COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8faccf8d55902e6701ebb4ce534b84705304fd5f..1fe8795ddf00232eba5a60a130e0845a6f6a8e17 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -802,6 +802,7 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index f867cd15b67dbd43650d8012b4299845af7200a8..2aa5d77b108d0f000550bd4d57fc4a922e7594e6 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import function_utils +from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -76,7 +77,12 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin All `Operation`s returned from `computation` will be executed when evaluating any of the returned output tensors. - inputs: A list of input tensors or `None` (equivalent to an empty list). + inputs: A list of inputs or `None` (equivalent to an empty list). Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimention list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. Returns: A list of output tensors. @@ -260,17 +266,10 @@ def _compile_internal(computation, inputs=None): if not isinstance(inputs, collections.Sequence): raise TypeError('inputs must be a list') + # Flatten inputs. + flat_inputs = nest.flatten(inputs) # Converts inputs to Tensors. - inputs = [ops.convert_to_tensor(x) for x in inputs] - input_arity = len(inputs) - - arg_error = check_function_argument_count( - computation, input_arity, infeed_queue=None) - if arg_error is not None: - raise TypeError( - 'Supplied computation cannot be called with the specified inputs. You ' - 'specified %d inputs: %s, but the computation needs %s' % - (input_arity, str([i.name for i in inputs]), arg_error)) + flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] cluster_name = ops.get_default_graph().unique_name('cluster') pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') @@ -280,11 +279,15 @@ def _compile_internal(computation, inputs=None): # Add identity ops so even unused inputs are 'consumed' by the # computation. - computation_inputs = [ + flat_inputs = [ array_ops.identity(x, name='input_{}'.format(i)) - for i, x in enumerate(inputs) + for i, x in enumerate(flat_inputs) ] + # Re-pack flat_inputs in same structure as 'inputs'. + computation_inputs = nest.pack_sequence_as( + structure=inputs, flat_sequence=flat_inputs) + # Only resource variables work inside an XLA computation, so turn on # resource variables for the computation. vscope = variable_scope.get_variable_scope() diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 8a8dc159ade6f2a4a9b5ec29055ea4848492b29f..dbcaf8185fb7a9d2bcf22376439c0ebd49accb1a 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -43,28 +43,19 @@ the workers. Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras] (https://www.tensorflow.org/guide/keras). -Take a very simple model consisting of a single layer: +Let's define a simple input dataset for training this model. Note that currently we require using +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +with `DistributionStrategy`. ```python import tensorflow as tf from tensorflow import keras -inputs = tf.keras.layers.Input(shape=(1,)) -predictions = tf.keras.layers.Dense(1)(inputs) -model = tf.keras.models.Model(inputs=inputs, outputs=predictions) -``` - -Let's also define a simple input dataset for training this model. Note that currently we require using -[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) -with `DistributionStrategy`. - -```python features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) train_dataset = tf.data.Dataset.zip((features, labels)) ``` - To distribute this Keras model on multiple GPUs using `MirroredStrategy` we first instantiate a `MirroredStrategy` object. @@ -72,14 +63,17 @@ first instantiate a `MirroredStrategy` object. distribution = tf.contrib.distribute.MirroredStrategy() ``` -We then compile the Keras model and pass the `MirroredStrategy` object in the -`distribute` argument (apart from other usual arguments like `loss` and -`optimizer`). +Take a very simple model consisting of a single layer. We need to create and compile +the model under the distribution strategy scope. ```python -model.compile(loss='mean_squared_error', - optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), - distribute=distribution) +with distribution.scope(): + inputs = tf.keras.layers.Input(shape=(1,)) + predictions = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.models.Model(inputs=inputs, outputs=predictions) + + model.compile(loss='mean_squared_error', + optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2)) ``` To train the model we call Keras `fit` API using the input dataset that we diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 8ec73654e30e4967f318c558ba94301e84a206e4..1bcc453a7e85b4e6737724cb5bdcd153cdd8c8ea 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -30,6 +30,7 @@ from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.contrib.distribute.python.tpu_strategy import initialize_tpu_system from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig @@ -58,11 +59,14 @@ _allowed_symbols = [ 'StandardSingleLossStep', 'ReplicaContext', 'TPUStrategy', + 'initialize_tpu_system', 'get_cross_replica_context', 'get_distribution_strategy', 'get_loss_reduction', 'get_replica_context', + 'get_strategy', 'has_distribution_strategy', + 'has_strategy', 'in_cross_replica_context', 'require_replica_context', 'run_standard_tensorflow_server', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 2d6a08df9a266095d3a9619e692ae9b96af93ceb..d4758d7518f4e209d58d559723c67e90473d8a79 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -1,8 +1,8 @@ # Implementation of a prototype TF distributed computation library. +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") package( default_visibility = [ @@ -23,17 +23,14 @@ cuda_py_test( additional_deps = [ ":combinations", ":mirrored_strategy", - ":multi_worker_test_base", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:device_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", @@ -45,14 +42,36 @@ cuda_py_test( ], ) +cuda_py_test( + name = "input_lib_test", + srcs = ["input_lib_test.py"], + additional_deps = [ + ":combinations", + ":mirrored_strategy", + ":multi_worker_test_base", + "@absl_py//absl/testing:parameterized", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:values", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "no_pip", + ], +) + py_library( name = "mirrored_strategy", srcs = ["mirrored_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:values", ], ) @@ -61,18 +80,10 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mirrored_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:parameter_server_strategy", "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", ], ) @@ -119,6 +130,8 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", @@ -139,7 +152,9 @@ py_library( "//tensorflow/python:training", "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:cross_device_utils", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], @@ -165,6 +180,7 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//third_party/py/numpy", ], ) @@ -289,6 +305,8 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", ], @@ -660,7 +678,9 @@ cuda_py_test( additional_deps = [ ":keras_correctness_test_lib", ], - shard_count = 16, + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 19, tags = [ "multi_and_single_gpu", "no_oss", # TODO(b/117919883): Fix python error. diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 12197c3d0dedee23d12732b8d4398f43bfc61caa..eee07543251321ae0c9eef57851431cf97c65643 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -26,9 +26,12 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values from tensorflow.python.eager import context +from tensorflow.python.eager import tape from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops @@ -85,9 +88,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: local_devices = ("/device:CPU:0",) self._worker_device = device_util.canonicalize("/device:CPU:0") + self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) + # TODO(yuefengz): remove num_gpus_per_worker from CollectiveAllReduce. self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, num_gpus_per_worker=num_gpus_per_worker, @@ -120,6 +125,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): task_id) self._worker_device = "/job:%s/task:%d" % (task_type, task_id) + self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) if num_gpus_per_worker: local_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) @@ -130,7 +136,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) - self._input_workers = values.InputWorkers( + self._input_workers = input_lib.InputWorkers( self._device_map, [(self._worker_device, self.worker_devices)]) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, @@ -156,19 +162,23 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): if colocate_with is None: device_map = self._device_map logical_device = 0 # TODO(josh11b): Get logical device from scope here. + elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) else: device_map = colocate_with.device_map logical_device = colocate_with.logical_device - group_size = device_map.num_replicas_in_graph * self._num_workers - group_key = self._collective_keys.get_group_key(self.worker_devices) def _real_mirrored_creator(devices, *args, **kwargs): """Creates one MirroredVariable on the current worker.""" - value_list = [] unique_var_name = ops.get_default_graph().unique_name( kwargs["name"], mark_as_used=False).rstrip("/") + # pylint: disable=protected-access collective_instance_key = self._collective_keys.get_instance_key( key_id=unique_var_name) + # Only the first device participles in the broadcast of initial values. + group_key = self._collective_keys.get_group_key([devices[0]]) + group_size = self._num_workers if "initial_value" not in kwargs: raise ValueError("Initial value must be specified.") initial_value = kwargs["initial_value"] @@ -177,9 +187,33 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: initial_value_fn = lambda: initial_value + value_list = [] for i, d in enumerate(devices): - with ops.device(d): - if i > 0: + with ops.init_scope(), ops.device(d): + if i == 0: + # The initial value fn makes sure variables all initialized to + # same values. The first device of the chief worker will send their + # variable values to other workers. + def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring + with ops.device(device): + initial_value = initial_value_fn() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor(initial_value) + + assert index == 0, index + if self._num_workers > 1: + if self._is_chief: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv( + initial_value.shape, initial_value.dtype, group_size, + group_key, collective_instance_key) + return initial_value + else: # Give replicas meaningful distinct names: var0name = value_list[0].name.split(":")[0] # We append a / to variable names created on replicas with id > 0 to @@ -187,30 +221,22 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # The initial value fn makes sure variables all initialized to - # same values. The first device of the chief worker will send their - # variable values to other devices and other workers. - def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring - with ops.device(device): - initial_value = initial_value_fn() - assert not callable(initial_value) - initial_value = ops.convert_to_tensor(initial_value) - - if self._is_chief and index == 0: - bcast_send = collective_ops.broadcast_send( - initial_value, initial_value.shape, initial_value.dtype, - group_size, group_key, collective_instance_key) - with ops.control_dependencies([bcast_send]): - return array_ops.identity(initial_value) - else: - return collective_ops.broadcast_recv( - initial_value.shape, initial_value.dtype, group_size, - group_key, collective_instance_key) + # Variables on non-first replica get initial values from the + # variables created on the first device of each worker. + def _overridden_initial_value_fn(device=d, index=i): + assert index > 0 + with ops.device(device): + if context.executing_eagerly(): + return array_ops.identity(value_list[0].value()) + else: + return array_ops.identity(value_list[0].initial_value) kwargs["initial_value"] = _overridden_initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) + # Don't record operations (e.g. other variable reads) during + # variable creation. + with tape.stop_recording(): + v = next_creator(*args, **kwargs) if i == 0: actual_var_name = v.name.split(":")[0] @@ -229,13 +255,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. worker_index = 0 - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, worker_index, prefetch_on_device=True) def _make_dataset_iterator(self, dataset): - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -252,7 +278,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): input_pipeline_id=input_pipeline_id, num_replicas_in_sync=self._num_replicas_in_sync) - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, [input_context]) def _configure(self, @@ -346,4 +372,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): - return False + """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. + + `distribute_dataset` and `make_input_fn_iterator` assume per-replica + batching. + + Returns: + Boolean. + """ + return True diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 0fb672dded7624e798592d2f5c01945aa830021e..62ff4b178ea8a69c603d0e213de9efbada41eddb 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -192,6 +192,7 @@ class CollectiveAllReduceStrategyTestBase( image = random_ops.random_uniform([2, 28, 28]) label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32) logits = model(image, training=True) + # TODO(yuefengz): make loss a callable for eager mode. loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits) optimizer = adam.AdamOptimizer(learning_rate=1e-4) train_op = optimizer.minimize(loss, @@ -397,28 +398,38 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class LocalCollectiveAllReduceStrategy( + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): - def testMinimizeLossGraph(self, num_gpus=2): + @combinations.generate( + combinations.combine( + mode=['graph', 'eager'], num_gpus=[2, 4], required_gpus=2)) + def testMinimizeLoss(self, num_gpus): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_minimize_loss_graph(None, None, num_gpus) + if context.executing_eagerly(): + strategy, _, _ = self._get_test_object(None, None, num_gpus) + self._test_minimize_loss_eager(strategy) + else: + self._test_minimize_loss_graph(None, None, num_gpus) - def testComplexModel(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[2, 4], required_gpus=2)) + def testComplexModel(self, num_gpus): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._test_complex_model(None, None, num_gpus) - def testMakeInputFnIterator(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. - if context.num_gpus() < num_gpus: - self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(10) - expected_values = [[i, i+1] for i in range(0, 10, 2)] + @combinations.generate( + combinations.combine(mode=['graph', 'eager'], required_gpus=2)) + def testMakeInputFnIterator(self): + num_gpus = 2 + dataset_fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) + expected_values = [range(i, i + num_gpus) for i in range(0, 10, num_gpus)] input_fn = self._input_fn_to_test_input_context( dataset_fn, @@ -428,6 +439,49 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, self._test_input_fn_iterator(None, None, num_gpus, input_fn, expected_values) + def testAllReduceSum(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradient_tape(distribution) + + def testNumpyIterator(self): + num_gpus = 2 + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + strategy, _, _ = self._get_test_object(None, None, num_gpus) + self._test_numpy_iterator(strategy) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 4a934953ad2d4c6ecbe2bde2333a49bf8fd72821..f6c4291659b493b59b102b7ca83454f804898772 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,7 +46,7 @@ import unittest from absl.testing import parameterized import six -from tensorflow.contrib.cluster_resolver import TPUClusterResolver +from tensorflow.contrib import cluster_resolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib @@ -321,6 +321,15 @@ class NamedDistribution(object): return self._required_tpu +def _get_tpu_strategy_creator(steps_per_run): + def _create_tpu_strategy(): + resolver = cluster_resolver.TPUClusterResolver("") + tpu_lib.initialize_tpu_system(resolver) + strategy = tpu_lib.TPUStrategy(resolver, steps_per_run=steps_per_run) + return strategy + return _create_tpu_strategy + + # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", @@ -330,13 +339,12 @@ one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) tpu_strategy = NamedDistribution( - "TPU", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=2), + "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True) tpu_strategy_one_step = NamedDistribution( - "TPUOneStep", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=1), + "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True) + mirrored_strategy_with_one_cpu = NamedDistribution( "Mirrored1CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 60fda996642464135fe1fb8c314bcf7f04d19362..1ce91ecaf22a80a53124c8f00fac05c6b4711ed9 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -109,22 +109,21 @@ def main(_): tf.enable_eager_execution() train_ds, eval_ds, input_shape = get_input_datasets() - model = get_model(input_shape) # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) - optimizer = rmsprop.RMSProp(learning_rate=0.001) - - # Compile the model by passing the distribution strategy object to the - # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed - # based on the strategy instantiated. - model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=optimizer, - metrics=['accuracy'], - distribute=strategy) + # Create and compile the model under Distribution strategy scope. + # `fit`, `evaluate` and `predict` will be distributed based on the strategy + # model was compiled with. + with strategy.scope(): + model = get_model(input_shape) + optimizer = rmsprop.RMSProp(learning_rate=0.001) + model.compile(loss=tf.keras.losses.categorical_crossentropy, + optimizer=optimizer, + metrics=['accuracy']) # Train the model with the train dataset. model.fit(x=train_ds, epochs=20, steps_per_epoch=468) diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f589cd6ad54ea8f33002cb067ef8d83d3d33036a --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -0,0 +1,480 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the input_lib library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.util import nest + + +class PerReplicaDatasetTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _test_iterator(self, devices, dataset, expected_values): + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map) + per_replica_dataset = input_lib.PerReplicaDataset(dataset, input_workers, 0) + if context.executing_eagerly(): + iterator = per_replica_dataset.make_one_shot_iterator() + else: + iterator = per_replica_dataset.make_initializable_iterator() + self.evaluate([iterator.initializer]) + + for expected_value in expected_values: + next_element = iterator.get_next_as_list() + computed_value = self.evaluate(next_element) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next_as_list() + self.evaluate(next_element) + + @test_util.run_in_graph_and_eager_modes + def testOneDevice(self): + devices = ["/device:CPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleDevices(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTupleDataset(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnevenDatasetBatches(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(devices, dataset, expected_values) + + def testInitializableIterator(self): + with context.graph_mode(): + devices = ["/device:CPU:0"] + # Using random input since that is only allowed with initializable + # iterator. + dataset = dataset_ops.Dataset.from_tensor_slices( + random_ops.random_uniform((10,))) + + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map) + per_replica_dataset = input_lib.PerReplicaDataset( + dataset, input_workers, 0) + iterator = per_replica_dataset.make_initializable_iterator() + + self.evaluate(iterator.initializer) + next_element = iterator.get_next_as_list() + for _ in range(10): + self.evaluate(next_element) + + # Should fail after the input is finished. + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(next_element) + + # After re-initializing the iterator, should be able to iterate again. + self.evaluate(iterator.initializer) + for _ in range(10): + self.evaluate(next_element) + + +class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): + + def _test_iterator(self, sess, iterator, devices, expected_values): + next_element = iterator.get_next() + for r, device in enumerate(devices): + v = values.select_replica(r, next_element) + # The `v` here can be a tuple. + for element in nest.flatten(v): + self.assertTrue(element.device in device) + + for expected_value in expected_values: + t = [values.select_replica(r, next_element) for r in range(len(devices))] + actual = sess.run(t) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + + def _test_dataset(self, dataset_fn, worker_devices, devices, + expected_values): + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_devices) + multi_worker_dataset = input_lib.MultiWorkerDataset( + dataset_fn, input_workers) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.cached_session() as sess: + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, expected_values) + + def _cpu_devices(self): + worker_devices = ( + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"]) + ) + devices = [ + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_devices, devices + + def _cpu_and_one_gpu_devices(self): + worker_devices = ( + ("/job:worker/replica:0/task:0", ( + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + )), + ("/job:worker/replica:0/task:1", ( + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + )) + ) + devices = [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_devices, devices + + def testDataDistributionOneDevicePerWorker(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + + def testDataDistributionTwoDevicePerWorker(self): + if context.num_gpus() < 1: + self.skipTest("A GPU is not available for this test.") + worker_devices, devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) + + def testTupleDataset(self): + worker_devices, devices = self._cpu_devices() + + with context.graph_mode(): + + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(8) + dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] + self._test_dataset(dataset_fn, worker_devices, devices, + expected_values) + + def testInitializableIterator(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(8) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_devices) + multi_worker_dataset = input_lib.MultiWorkerDataset( + dataset_fn, input_workers) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + + sess.run(multi_worker_iterator.initializer) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + + # After re-initializing the iterator, should be able to iterate again. + sess.run(multi_worker_iterator.initializer) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + + def testValueErrorForIterator(self): + # Incompatiable arguments. + d1 = "/device:GPU:0" + d2 = "/device:GPU:1" + device_map = values.ReplicaDeviceMap([d1, d2]) + input_workers = input_lib.InputWorkers( + device_map, (("w1", (d1,)), ("w2", (d2,)))) + with self.assertRaises(ValueError): + input_lib.MultiWorkerDataIterator([("w1", None)], input_workers) + + def testDuplicateDevices(self): + _, devices = self._cpu_devices() + devices.append("/job:worker/replica:0/task:0/device:CPU:0") + with self.assertRaises(ValueError): + _ = values.ReplicaDeviceMap(devices) + + +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = input_lib.InputFunctionIterator( + input_fn, input_workers, input_contexts) + else: + iterator = input_lib.DatasetIterator( + dataset_fn(), input_workers, split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + + # After re-initializing the iterator, should be able to iterate again. + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class SplitDatasetBatchTest(test.TestCase): + + def testBatchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testMapAndBatchDataset(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testPrefetchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_correctness_test.py b/tensorflow/contrib/distribute/python/keras_correctness_test.py index 3abcdde65b2ecf79a9c7f26dd4c6d1325dd2c5d9..3abdee2c0ed8ebe60876a6d43bb90d8f4b7a5052 100644 --- a/tensorflow/contrib/distribute/python/keras_correctness_test.py +++ b/tensorflow/contrib/distribute/python/keras_correctness_test.py @@ -34,6 +34,8 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_de from tensorflow.python.training import gradient_descent _RANDOM_SEED = 1337 +_EVAL_STEPS = 20 +_GLOBAL_BATCH_SIZE = 64 # Note: Please make sure the tests in this file are also covered in # keras_backward_compat_test for features that are supported with both APIs. @@ -61,12 +63,23 @@ def all_strategy_combinations_with_graph_mode(): def strategy_and_input_combinations(): + def cnn_model_with_batch_norm(**kwargs): + return _create_cnn_model(with_batch_norm=True, **kwargs) + return ( combinations.times( combinations.combine(distribution=all_strategies), combinations.combine(mode=['graph', 'eager'], use_numpy=[True, False], - use_validation_data=[True, False]))) + use_validation_data=[True, False]), + combinations.combine(model_with_data=[ + ModelWithData('dnn', _create_dnn_model, _dnn_training_data), + ModelWithData('cnn', _create_cnn_model, _cnn_training_data), + ModelWithData('cnn_batch_norm', + cnn_model_with_batch_norm, + _cnn_training_data, + with_batch_norm=True), + ]))) class MaybeDistributionScope(object): @@ -87,7 +100,35 @@ class MaybeDistributionScope(object): self._scope = None -def _create_dnn_model(weights=None, distribution=None, compile_model=True): +class ModelWithData(object): + """An object giving a good name in combinations. + + The model_fn must take two arguments: initial_weights and distribution. + """ + + def __init__(self, name, model_fn, data_fn, with_batch_norm=False): + self.name = name + self.model_fn = model_fn + self.data_fn = data_fn + self.with_batch_norm = with_batch_norm + + def __repr__(self): + return self.name + + +def _dnn_training_data(): + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] + return x_train, y_train, x_predict + + +def _create_dnn_model(initial_weights=None, distribution=None): with MaybeDistributionScope(distribution): # We add few non-linear layers to make it non-trivial. model = keras.Sequential() @@ -96,17 +137,60 @@ def _create_dnn_model(weights=None, distribution=None, compile_model=True): model.add(keras.layers.Dense(10, activation='relu')) model.add(keras.layers.Dense(1)) - if weights: - model.set_weights(weights) + if initial_weights: + model.set_weights(initial_weights) - if compile_model: - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent_keras.SGD(0.5), - metrics=['mse']) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse']) return model +def _cnn_training_data(count=_GLOBAL_BATCH_SIZE * _EVAL_STEPS, + shape=(28, 28, 3), num_classes=10): + centers = np.random.randn(num_classes, *shape) + + features = [] + labels = [] + for _ in range(count): + label = np.random.randint(0, num_classes, size=1)[0] + offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape)) + offset = offset.reshape(shape) + labels.append(label) + features.append(centers[label] + offset) + + x_train = np.asarray(features, dtype=np.float32) + y_train = np.asarray(labels, dtype=np.float32).reshape((count, 1)) + x_predict = x_train + return x_train, y_train, x_predict + + +def _create_cnn_model(initial_weights=None, distribution=None, + with_batch_norm=False): + with MaybeDistributionScope(distribution): + image = keras.layers.Input(shape=(28, 28, 3), name='image') + c1 = keras.layers.Conv2D( + name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4))( + image) + if with_batch_norm: + c1 = keras.layers.BatchNormalization(name='bn1')(c1) + c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1) + logits = keras.layers.Dense( + 10, activation='softmax', name='pred')( + keras.layers.Flatten()(c1)) + model = keras.Model(inputs=[image], outputs=[logits]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + def batch_wrapper(dataset, batch_size, distribution, repeat=None): if repeat: dataset = dataset.repeat(repeat) @@ -133,7 +217,7 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, with_distribution, x_train, y_train, x_predict): """Generates the inputs for correctness check when enable Keras with DS.""" training_epochs = 2 - global_batch_size = 64 + global_batch_size = _GLOBAL_BATCH_SIZE batch_size = get_batch_size(global_batch_size, with_distribution) if use_numpy: @@ -158,6 +242,11 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, 'x': np.array(x_predict, dtype=np.float32), } else: + if len(x_train) < _GLOBAL_BATCH_SIZE * _EVAL_STEPS: + # Currently, we cannot detech the size of a dataset. So, the eval steps is + # hard coded. + raise ValueError('x_train must have at least ' + '_GLOBAL_BATCH_SIZE * _EVAL_STEPS samples') # For dataset inputs, we do not pass batch_size to # keras.fit/evaluate/predict. The batch size is part of the dataset. train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) @@ -183,7 +272,7 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, 'batch_size': None, 'x': x, 'y': None, - 'steps': 20, + 'steps': _EVAL_STEPS, } predict_batch_size = get_batch_size(len(x_predict), with_distribution) @@ -199,8 +288,8 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, def fit_eval_and_predict( - initial_weights, input_fn, distribution=None): - model = _create_dnn_model(initial_weights, distribution) + initial_weights, input_fn, model_fn, distribution=None): + model = model_fn(initial_weights=initial_weights, distribution=distribution) training_inputs, eval_inputs, predict_inputs = input_fn(distribution) result = {} @@ -351,7 +440,8 @@ class TestDistributionStrategyCorrectness(test.TestCase, self.assertEqual(outs[2], 0.) @combinations.generate(strategy_and_input_combinations()) - def test_correctness(self, distribution, use_numpy, use_validation_data): + def test_correctness(self, distribution, use_numpy, use_validation_data, + model_with_data): if self._should_skip_tpu_with_eager(distribution): self.skipTest('TPUStrategy does not support eager mode now.') @@ -366,21 +456,15 @@ class TestDistributionStrategyCorrectness(test.TestCase, np.random.seed(_RANDOM_SEED) random_seed.set_random_seed(_RANDOM_SEED) + model_fn, data_fn = model_with_data.model_fn, model_with_data.data_fn # Train, eval, and predict datasets are created with the same input numpy # arrays. - # TODO(xiejw): Change this back to 10000, once we support final partial - # batch. - num_samples = 9984 - x_train = np.random.rand(num_samples, 1) - y_train = 3 * x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - x_predict = [[1.], [2.], [3.], [4.]] + x_train, y_train, x_predict = data_fn() # The model is built once and the initial weights are saved. # This is used to initialize the model for both the distribution and # non-distribution run. - model = _create_dnn_model(compile_model=False) + model = model_fn() initial_weights = model.get_weights() def input_fn(dist): @@ -388,11 +472,22 @@ class TestDistributionStrategyCorrectness(test.TestCase, use_numpy, use_validation_data, dist, x_train, y_train, x_predict) results_with_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, distribution=distribution) + initial_weights, input_fn=input_fn, model_fn=model_fn, + distribution=distribution) results_without_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, distribution=None) - compare_results(results_with_ds, results_without_ds, distribution, - testcase=self) + initial_weights, input_fn=input_fn, model_fn=model_fn, + distribution=None) + + # First, special case, for multi-replica distributed training, batch norm + # is not aggregated globally. So it is expected to have different weights. + if (model_with_data.with_batch_norm and + distribution.num_replicas_in_sync > 1): + with self.assertRaises(AssertionError): + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + else: + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) @combinations.generate(all_strategy_combinations_with_graph_mode()) def test_dynamic_lr(self, distribution): @@ -403,13 +498,9 @@ class TestDistributionStrategyCorrectness(test.TestCase, np.random.seed(_RANDOM_SEED) random_seed.set_random_seed(_RANDOM_SEED) - # TODO(xiejw): Change this back to 10000, once we support final partial - # batch. - num_samples = 9984 - x_train = np.random.rand(num_samples, 1).astype('float32') - y_train = 3 * x_train + x_train, y_train, _ = _dnn_training_data() - model = _create_dnn_model(compile_model=False) + model = _create_dnn_model() initial_weights = model.get_weights() update_freq = None @@ -439,9 +530,11 @@ class TestDistributionStrategyCorrectness(test.TestCase, return training_inputs, eval_inputs, predict_inputs results_with_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, distribution=distribution) + initial_weights, input_fn=input_fn, model_fn=_create_dnn_model, + distribution=distribution) results_without_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, distribution=None) + initial_weights, input_fn=input_fn, model_fn=_create_dnn_model, + distribution=None) compare_results(results_with_ds, results_without_ds, distribution, testcase=self) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index b99d11b923155b1362b434ffb82a96a11b08d352..40916afcfaa8c99dc6e7400f0c88aee2f668a46c 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -245,6 +245,18 @@ def all_strategy_combinations(): return strategy_minus_tpu_combinations() + tpu_strategy_combinations() +def all_strategy_combinations_minus_default(): + strategy_minus_default_combinations = combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager']) + return strategy_minus_default_combinations + tpu_strategy_combinations() + + # TODO(priyag): Add v2 optimizers here. def strategy_and_optimizer_combinations(): return combinations.times( @@ -417,15 +429,6 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_creating_var_with_numpy_arrays(self, distribution): - with self.cached_session(): - x = np.asarray(np.random.random((64, 3)), dtype=np.float32) - var_x = distributed_training_utils.get_var_for_numpy(distribution, x) - val = self.evaluate(var_x.value()) - # Verify that the numpy value is copied to the variable. - self.assertAllEqual(x, val) - @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -564,26 +567,26 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, metrics = ['mae'] model.compile(optimizer, loss, metrics=metrics) - inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) - # Call fit with validation data - model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, - validation_data=(inputs, targets)) + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) - # TODO(anjalisridhar): We need tests for when the batch size and steps are - # smaller and results in a 0 batch_size and steps value. - model.evaluate(inputs, targets) - # with steps - model.evaluate(inputs, targets, steps=2) - # with batch_size - model.evaluate(inputs, targets, batch_size=8) + # TODO(anjalisridhar): We need tests for when the batch size and steps + # are smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) - model.predict(inputs) - # with steps - model.predict(inputs, steps=2) - # with batch_size - model.predict(inputs, batch_size=8) + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): @@ -1149,5 +1152,37 @@ class TestDistributionStrategyWithNormalizationLayer( np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) +class TestDistributionStrategyValidation(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_layer_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + with distribution.scope(): + model = keras.Model(x, y) + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_model_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + with distribution.scope(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 32a0d199434e0627122fd4e47cf8894079ef3a1e..7472f6dde45a75633133b33b181db4d1a7aba287 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -122,7 +122,6 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): batches_per_update = distribution.num_replicas_in_sync self.evaluate(iterator.initializer) - self.evaluate(distribution.initialize()) self.evaluate(variables.local_variables_initializer()) batches_consumed = 0 @@ -136,8 +135,6 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if batches_consumed >= 4: # Consume 4 input batches in total. break - self.evaluate(distribution.finalize()) - @combinations.generate(all_combinations() + tpu_combinations()) def testMean(self, distribution): def _dataset_fn(): diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 824c4b09371fcc8d590f2d2b2be8f39b4a585b27..b0e24a53f6f6a9867150f4b81a0bd3213757e0b3 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -75,7 +75,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return distribution.run_steps_on_dataset( step_fn, iterator, iterations=2).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -84,12 +83,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) @@ -152,7 +148,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # `distribution.scope`. with variable_scope.variable_creator_scope( appending_creator), distribution.scope(): - model_fn, dataset_fn, layer = minimize_loss_example( + model_fn, dataset_fn, _ = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=True, @@ -169,16 +165,12 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return distribution.run_steps_on_dataset( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) - run_step() - self.evaluate(distribution.finalize()) - def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], @@ -241,7 +233,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return distribution.run_steps_on_dataset( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -267,8 +258,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.combine( @@ -335,7 +324,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return distribution.run_steps_on_dataset( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -370,8 +358,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -458,7 +444,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): reduced=False, distribution=distribution) return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -471,8 +456,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:])) self.assertTrue(loss_is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 71e50b83b079bc73a7b178356f0f26adbd98638f..2e23a51ee56ed1388a4387a51342aabce6d24bed 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -21,8 +21,8 @@ from __future__ import print_function import functools from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import values # pylint: disable=protected-access,invalid-name @@ -48,8 +48,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): distributed environment. There are several important concepts for distributed TensorFlow, e.g. - `client`, `job`, 'task', `cluster`, `in-graph replication` and - 'synchronous training' and they have already been defined in the + `client`, `job`, `task`, `cluster`, `in-graph replication` and + `synchronous training` and they have already been defined in the [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). The distribution strategy inherits these concepts as well and in addition to that we also clarify several more concepts: @@ -104,6 +104,61 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): auto_shard_dataset) super(MirroredStrategy, self).__init__(extended) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. + + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `MirroredStrategy`. + + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. + + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. + + Args: + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + return super(MirroredStrategy, self).make_dataset_iterator(dataset) + + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. + + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `MirroredStrategy`. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + return super(MirroredStrategy, self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) + class MirroredExtended(CoreMirroredExtended): """Implementation of (contrib) MirroredStrategy.""" @@ -135,14 +190,14 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. """ - return values.DatasetIterator(dataset, self._input_workers) + return input_lib.DatasetIterator(dataset, self._input_workers) def _distribute_dataset(self, dataset_fn): if self._local_mode: - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, 0) else: - return values.MultiWorkerDataset( + return input_lib.MultiWorkerDataset( functools.partial(self._call_dataset_fn, dataset_fn), self._input_workers, auto_shard=self._auto_shard_dataset) @@ -150,4 +205,5 @@ class MirroredExtended(CoreMirroredExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """The contrib version of Mirrored strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9cbc34412da2d498fa1f0624f5cba466f663c194..59d711ae0197716f1af31c24c5748bce543758dc 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -66,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0] combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"])) -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class MirroredTwoDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): def testMinimizeLoss(self, distribution): if context.executing_eagerly(): @@ -114,9 +116,30 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) + def testGlobalStepUpdate(self, distribution): self._test_global_step_update(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + def one_device_combinations(): return combinations.combine( @@ -128,25 +151,42 @@ def one_device_combinations(): mode=["graph", "eager"]) +@combinations.generate(one_device_combinations()) class MirroredOneDeviceDistributionTest( strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase, parameterized.TestCase): - @combinations.generate(one_device_combinations()) def testMinimizeLoss(self, distribution): if context.executing_eagerly(): self._test_minimize_loss_eager(distribution) else: self._test_minimize_loss_graph(distribution) - @combinations.generate(one_device_combinations()) def testReplicaId(self, distribution): self._test_replica_id(distribution) - @combinations.generate(one_device_combinations()) def testCallAndMergeExceptions(self, distribution): self._test_call_and_merge_exceptions(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + class MirroredStrategyVariableCreatorStackTest( test.TestCase, parameterized.TestCase): diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 700751d68c5a517876bf41d09223e4c1a40b50c8..836cb7cc41b62352fd69a4a209d483ccf0fc498e 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -20,6 +20,8 @@ from __future__ import print_function from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -49,10 +51,11 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): super(OneDeviceExtended, self).__init__(container_strategy) self._device = device self._default_device = device - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] + self._input_device = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(self._input_device, [self._device])] device_map = values.SingleDeviceMap(device) - self._input_workers = values.InputWorkers(device_map, worker_device_pairs) + self._input_workers = input_lib.InputWorkers( + device_map, worker_device_pairs) def _create_variable(self, next_creator, *args, **kwargs): colocate_with = kwargs.pop("colocate_with", None) @@ -67,19 +70,23 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterator from dataset without splitting the batch.""" - return values.DatasetIterator(dataset, self._input_workers) + return input_lib.DatasetIterator(dataset, self._input_workers) def _distribute_dataset(self, dataset_fn): - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, 0) def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, [distribute_lib.InputContext()]) + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, self._input_device, session) + def _broadcast_to(self, tensor, destinations): del destinations return tensor @@ -91,7 +98,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() def body(i, *args): """A wrapper around `fn` to create the while loop body.""" del args @@ -192,6 +199,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """Global and per-replica batching are equivalent for OneDeviceStrategy.""" return True diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index d46cd6f529e363f76bfa2b22339add63530cfde8..f81466a6c75f1cf287cdb00917872f77383c615e 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -25,7 +25,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util -class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): +class OneDeviceStrategyTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase): def _get_distribution_strategy(self): return one_device_strategy.OneDeviceStrategy("/device:CPU:0") @@ -57,6 +59,28 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): self._test_input_fn_iterator( iterator, d.extended.worker_devices, expected_values) + @test_util.run_in_graph_and_eager_modes + def testNumpyIterator(self): + self._test_numpy_iterator(self._get_distribution_strategy()) + + def testAllReduceSum(self): + self._test_all_reduce_sum(self._get_distribution_strategy()) + + def testAllReduceSumGradients(self): + self._test_all_reduce_sum_gradients(self._get_distribution_strategy()) + + def testAllReduceSumGradientTape(self): + self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy()) + + def testAllReduceMean(self): + self._test_all_reduce_mean(self._get_distribution_strategy()) + + def testAllReduceMeanGradients(self): + self._test_all_reduce_mean_gradients(self._get_distribution_strategy()) + + def testAllReduceMeanGradientTape(self): + self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index a6e924b509fc15c04e15c6a5caeb6caed638395d..0cefef7545f3be1f8116fea3ffdb0e91888159de 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,34 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import device_setter -from tensorflow.python.util import nest +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import parameter_server_strategy +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver + +# pylint: disable=protected-access,invalid-name,line-too-long +CoreParameterServerStrategy = parameter_server_strategy.ParameterServerStrategy +CoreParameterServerExtended = parameter_server_strategy.ParameterServerStrategyExtended -_LOCAL_CPU = "/device:CPU:0" -_LOCAL_GPU_0 = "/device:GPU:0" +# pylint: enable=protected-access,invalid-name,line-too-long -# TODO(yuefengz): maybe cache variables on local CPU. -# TODO(yuefengz): we may want to set session options to disallow communication -# between workers. class ParameterServerStrategy(distribute_lib.DistributionStrategy): """A parameter server DistributionStrategy. + *** contrib version *** + This strategy class works for both local training and between-graph replicated training for multiple workers. If `cluster_spec` is specified, either passed in to __init__() method or parsed from the @@ -99,437 +89,84 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): super(ParameterServerStrategy, self).__init__( ParameterServerExtended(self, num_gpus_per_worker)) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. -class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of ParameterServerStrategy.""" + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `ParameterServerStrategy`. - def __init__(self, container_strategy, num_gpus_per_worker): - super(ParameterServerExtended, self).__init__(container_strategy) - self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local(num_gpus_per_worker) + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. - # We typically don't need to do all-reduce in this strategy. - self._cross_device_ops = ( - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( - reduce_to_device=_LOCAL_CPU)) - - def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, - task_type, task_id): - """Initialize devices for multiple workers. - - It creates variable devices and compute devices. Variables and operations - will be assigned to them respectively. We have one compute device per - replica. The variable device is a device function or device string. The - default variable device assigns variables to parameter servers in a - round-robin fashion. + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. Args: - num_gpus_per_worker: number of local GPUs or GPUs per worker. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. - Raises: - ValueError: if the cluster_spec doesn't have ps jobs. + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ - assert cluster_spec - if not task_type or task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`") - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - - worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) - - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. - if num_gpus_per_worker > 0: - compute_devices = tuple( - "%s/device:GPU:%d" % (worker_device, i) - for i in range(num_gpus_per_worker) - ) - else: - compute_devices = (worker_device,) - - self._device_map = values.ReplicaDeviceMap(compute_devices) - self._input_workers = values.InputWorkers( - self._device_map, [(worker_device, compute_devices)]) - - # In distributed mode, place variables on ps jobs in a round-robin fashion. - # Note that devices returned from `replica_device_setter` are not - # canonical and therefore we don't canonicalize all variable devices to - # make them consistent. - # TODO(yuefengz): support passing a strategy object to control variable - # assignment. - # TODO(yuefengz): merge the logic of replica_device_setter into this - # class. - num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) - if num_ps_replicas == 0: - raise ValueError("The cluster spec needs to have `ps` jobs.") - self._variable_device = device_setter.replica_device_setter( - ps_tasks=num_ps_replicas, - worker_device=worker_device, - merge_devices=True, - cluster=cluster_spec) - - # The `_parameter_devices` is needed for the `parameter_devices` property - # and is a list of all variable devices. Here parameter devices are all - # tasks of the "ps" job. - self._parameter_devices = tuple(map("/job:ps/task:{}".format, - range(num_ps_replicas))) - - # Add a default device so that ops without specified devices will not end up - # on other workers. - self._default_device = worker_device - - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) - self._cluster_spec = cluster_spec - self._task_type = task_type - self._task_id = task_id - - logging.info( - "Multi-worker ParameterServerStrategy with " - "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_ps_replicas = %r, is_chief = %r, device_map = %r, " - "variable_device = %r", cluster_spec.as_dict(), task_type, task_id, - num_ps_replicas, self._is_chief, self._device_map, - self._variable_device) - - def _initialize_local(self, num_gpus_per_worker): - """Initialize internal devices for local training.""" - worker_device = device_util.canonicalize("/device:CPU:0") - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. - if num_gpus_per_worker > 0: - compute_devices = tuple( - map("/device:GPU:{}".format, range(num_gpus_per_worker))) - else: - compute_devices = (_LOCAL_CPU,) - - self._device_map = values.ReplicaDeviceMap(compute_devices) - self._input_workers = values.InputWorkers( - self._device_map, [(worker_device, compute_devices)]) - - # If there is only one GPU, put everything on that GPU. Otherwise, place - # variables on CPU. - if num_gpus_per_worker == 1: - assert len(compute_devices) == 1 - self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = (_LOCAL_GPU_0,) - else: - self._variable_device = _LOCAL_CPU - self._parameter_devices = (_LOCAL_CPU,) - - self._is_chief = True - self._cluster_spec = None - self._task_type = None - self._task_id = None - - logging.info( - "ParameterServerStrategy with compute_devices = %r, " - "variable_device = %r", compute_devices, self._variable_device) - - def _validate_colocate_with_variable(self, colocate_with_variable): - values.validate_colocate(colocate_with_variable, self) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, 0, - prefetch_on_device=True) - - def _make_dataset_iterator(self, dataset): - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - """Distributes the dataset to each local GPU.""" - if self._cluster_spec: - input_pipeline_id = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) - num_input_pipelines = multi_worker_util.worker_count( - self._cluster_spec, self._task_type) - else: - input_pipeline_id = 0 - num_input_pipelines = 1 - input_context = distribute_lib.InputContext( - num_input_pipelines=num_input_pipelines, - input_pipeline_id=input_pipeline_id, - num_replicas_in_sync=self._num_replicas_in_sync) - return values.InputFunctionIterator( - input_fn, self._input_workers, [input_context]) - - def _broadcast_to(self, tensor, destinations): - # This is both a fast path for Python constants, and a way to delay - # converting Python values to a tensor until we know what type it - # should be converted to. Otherwise we have trouble with: - # global_step.assign_add(1) - # since the `1` gets broadcast as an int32 but global_step is int64. - if isinstance(tensor, (float, int)): - return tensor - if not cross_device_ops_lib.check_destinations(destinations): - # TODO(josh11b): Use current logical device instead of 0 here. - destinations = values.LogicalDeviceSpec( - device_map=self._device_map, logical_device=0) - return self._cross_device_ops.broadcast(tensor, destinations) - - def _allow_variable_partition(self): - return not context.executing_eagerly() - - # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through - # this creator, such as "MutableHashTable". - def _create_variable(self, next_creator, *args, **kwargs): - if self._num_replicas_in_sync > 1: - aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) - if aggregation not in ( - vs.VariableAggregation.NONE, - vs.VariableAggregation.SUM, - vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_REPLICA - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - def var_creator(*args, **kwargs): - """Create an AggregatingVariable and fix up collections.""" - # Record what collections this variable should be added to. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Create and wrap the variable. - v = next_creator(*args, **kwargs) - wrapped = values.AggregatingVariable( - self._container_strategy(), v, aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the contained - # variable to the TRAINABLE_VARIABLES collection, so we manually - # remove it and replace with the wrapper. We can't set "trainable" - # to False for next_creator() since that causes functions like - # implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - l.remove(v) - g.add_to_collections(collections, wrapped) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) - - return wrapped - else: - var_creator = next_creator - - if "colocate_with" in kwargs: - with ops.device(None): - with ops.colocate_with(kwargs["colocate_with"]): - return var_creator(*args, **kwargs) - - with ops.colocate_with(None, ignore_existing=True): - with ops.device(self._variable_device): - return var_creator(*args, **kwargs) - - def _call_for_each_replica(self, fn, args, kwargs): - # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica( - self._container_strategy(), self._device_map, fn, args, kwargs) + return super(ParameterServerStrategy, self).make_dataset_iterator(dataset) - def _verify_destinations_not_different_worker(self, destinations): - if not self._cluster_spec: - return - if destinations is None: - return - for d in cross_device_ops_lib.get_devices_from(destinations): - d_spec = tf_device.DeviceSpec.from_string(d) - if d_spec.job == self._task_type and d_spec.task != self._task_id: - raise ValueError( - "Cannot reduce to another worker: %r, current worker is %r" % - (d, self._input_workers.worker_devices[0])) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. - def _reduce_to(self, reduce_op, value, destinations): - self._verify_destinations_not_different_worker(destinations) - if not isinstance(value, values.DistributedValues): - # pylint: disable=protected-access - return cross_device_ops_lib.reduce_non_distributed_value( - reduce_op, self._device_map, value, destinations) - return self._cross_device_ops.reduce( - reduce_op, value, destinations=destinations) - - def _batch_reduce_to(self, reduce_op, value_destination_pairs): - for _, destinations in value_destination_pairs: - self._verify_destinations_not_different_worker(destinations) - return self._cross_device_ops.batch_reduce(reduce_op, - value_destination_pairs) - - def _select_single_value(self, structured): - """Select any single values in `structured`.""" - - def _select_fn(x): # pylint: disable=g-missing-docstring - if isinstance(x, values.Mirrored): - if len(x.devices) == 1: - return x.primary - else: - raise ValueError( - "You cannot update variable with a Mirrored object with multiple " - "components %r when using ParameterServerStrategy. You must " - "specify a single value or a Mirrored with a single value." % x) - elif isinstance(x, values.PerReplica): - raise ValueError( - "You cannot update variable with a PerReplica object %r when using " - "ParameterServerStrategy. You must specify a single value or a " - "Mirrored with a single value" % x) - else: - return x - - return nest.map_structure(_select_fn, structured) - - def _update(self, var, fn, args, kwargs, group): - if isinstance(var, values.AggregatingVariable): - var = var.get() - if not isinstance(var, resource_variable_ops.ResourceVariable): - raise ValueError( - "You can not update `var` %r. It must be a Variable." % var) - with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): - result = fn(var, *self._select_single_value(args), - **self._select_single_value(kwargs)) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - with ops.device( - colocate_with.device), distribute_lib.UpdateContext(colocate_with): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - return val.values - return (val,) - - def value_container(self, val): - if (hasattr(val, "_aggregating_container") and - not isinstance(val, values.AggregatingVariable)): - wrapper = val._aggregating_container() # pylint: disable=protected-access - if wrapper is not None: - return wrapper - return val - - def read_var(self, var): - # No need to distinguish between normal variables and replica-local - # variables. - return array_ops.identity(var) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - """Configures the strategy class. - - The strategy object will be re-initialized if `cluster_spec` is given but - was not passed in the constructor. + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `ParameterServerStrategy`. Args: - session_config: not used currently. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. - - Raises: - ValueError: if `cluster_spec` is given but `task_type` or `task_id` is - not. + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ - if not self._cluster_spec and cluster_spec: - # If a `cluster_spec` is already passed in, do nothing here. - # TODO(yuefengz): check `cluster_spec` is the same if this object has - # already been initialized with a `cluster_spec`. - if task_type is None or task_id is None: - raise ValueError("When `cluster_spec` is given, must also specify " - "`task_type` and `task_id`.") - self._cluster_spec = multi_worker_util.normalize_cluster_spec( - cluster_spec) - self._task_type = task_type - self._task_id = task_id - self._initialize_multi_worker(self._num_gpus_per_worker, - self._cluster_spec, task_type, task_id) - - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - if not self._cluster_spec: - updated_config.isolate_session_state = True - return updated_config - - updated_config.isolate_session_state = False + return super(ParameterServerStrategy, + self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) - assert self._task_type - assert self._task_id is not None - # The device filters prevent communication between workers. - if self._task_type not in ["chief", "worker"]: - return updated_config - del updated_config.device_filters[:] - updated_config.device_filters.extend( - ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) - return updated_config - - @property - def _num_replicas_in_sync(self): - return self._device_map.num_replicas_in_graph - - @property - def worker_devices(self): - return self._device_map.all_devices - - @property - def worker_devices_by_replica(self): - return self._device_map.devices_by_replica - - @property - def parameter_devices(self): - return self._parameter_devices - - def non_slot_devices(self, var_list): - return min(var_list, key=lambda x: x.name) - - @property - def experimental_between_graph(self): - # TODO(yuefengz): Should this return False in the local case? - return True - - @property - def experimental_should_init(self): - return self._is_chief +class ParameterServerExtended(CoreParameterServerExtended): + """Implementation of ParameterServerStrategy.""" - @property - def should_checkpoint(self): - return self._is_chief + def __init__(self, container_strategy, num_gpus_per_worker): + # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change + # the constructor's interface to allow customized cluster resolver. Use + # SimpleClusterResolver to override num_accelerators. + tfconfig = TFConfigClusterResolver() + cluster_resolver = SimpleClusterResolver( + cluster_spec=tfconfig.cluster_spec(), + task_type=tfconfig.task_type, + task_index=tfconfig.task_index, + num_accelerators=num_gpus_per_worker) + super(ParameterServerExtended, self).__init__( + container_strategy, cluster_resolver=cluster_resolver) - @property - def should_save_summary(self): - return self._is_chief + def _make_dataset_iterator(self, dataset): + return input_lib.DatasetIterator(dataset, self._input_workers) # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """The contrib version of PS strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index ce7065f220541ce2437f4768c312365a35197ab4..802809e7c7ec037ca3238c3e615308c701368d2a 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -29,10 +29,13 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import parameter_server_strategy as core_parameter_server_strategy from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config @@ -50,6 +53,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import training_util +from tensorflow.python.training.server_lib import ClusterSpec CHIEF = run_config.TaskType.CHIEF WORKER = run_config.TaskType.WORKER @@ -63,6 +67,57 @@ def _get_replica_id_integer(): return replica_id +class MockCoreParameterServerStrategy(distribute_lib.DistributionStrategy): + """Mock the strategy to allow cluster resolver as an argument.""" + + def __init__(self, cluster_resolver): + super(MockCoreParameterServerStrategy, self).__init__( + core_parameter_server_strategy.ParameterServerStrategyExtended( + self, cluster_resolver=cluster_resolver)) + + +def create_test_objects(cluster_spec=None, + task_type=None, + task_id=None, + num_gpus=None, + sess_config=None, + use_core_strategy=False): + sess_config = sess_config or config_pb2.ConfigProto() + if num_gpus is None: + num_gpus = context.num_gpus() + if use_core_strategy: + if cluster_spec and task_type and task_id is not None: + cluster_resolver = SimpleClusterResolver( + cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), + task_type=task_type, + task_index=task_id, + num_accelerators=num_gpus) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + cluster_resolver = SimpleClusterResolver( + ClusterSpec({}), num_accelerators=num_gpus) + target = '' + + distribution = MockCoreParameterServerStrategy(cluster_resolver) + sess_config = copy.deepcopy(sess_config) + sess_config = distribution.update_config_proto(sess_config) + else: + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if task_type: + sess_config = copy.deepcopy(sess_config) + distribution.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + target = '' + + return distribution, target, sess_config + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -76,24 +131,27 @@ class ParameterServerStrategyTestBase( self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True) super(ParameterServerStrategyTestBase, self).setUp() - def _get_test_objects(self, task_type, task_id, num_gpus): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=num_gpus) - if not task_type: - return distribution, '', self._sess_config - - sess_config = copy.deepcopy(self._sess_config) - distribution.configure( - session_config=sess_config, + def _get_test_objects(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): + return create_test_objects( cluster_spec=self._cluster_spec, task_type=task_type, - task_id=task_id) - return (distribution, 'grpc://' + self._cluster_spec[WORKER][task_id], - sess_config) - - def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + task_id=task_id, + num_gpus=num_gpus, + sess_config=self._sess_config, + use_core_strategy=use_core_strategy) + + def _test_device_assignment_distributed(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, config=sess_config) as sess, \ @@ -191,8 +249,9 @@ class ParameterServerStrategyTestBase( self.assertEqual(f_val, 46.0) def _test_device_assignment_distributed_enable_partitioner( - self, task_type, task_id, num_gpus): - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + self, task_type, task_id, num_gpus, use_core_strategy=False): + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) num_shards = len(d.parameter_devices) partitioner = partitioned_variables.fixed_size_partitioner(num_shards) with ops.Graph().as_default(), \ @@ -340,9 +399,13 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def _test_simple_increment(self, task_type, task_id, num_gpus): + def _test_simple_increment(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if d.extended._cluster_spec: num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) if 'chief' in d.extended._cluster_spec.as_dict(): @@ -410,9 +473,13 @@ class ParameterServerStrategyTestBase( y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) - def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + def _test_minimize_loss_graph(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if task_type: # Multi-worker assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec @@ -498,10 +565,15 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before - def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, - expected_values): + def _test_input_fn_iterator(self, + task_type, + task_id, + num_gpus, + input_fn, + expected_values, + use_core_strategy=False): distribution, master_target, config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) devices = distribution.extended.worker_devices with ops.Graph().as_default(), \ @@ -531,9 +603,11 @@ class ParameterServerStrategyTestBase( self.assertEqual(expected_value, computed_value) -class ParameterServerStrategyTest(ParameterServerStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class ParameterServerStrategyTest( + ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -541,66 +615,93 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2) cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] - def test_num_replicas_in_sync(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def test_num_replicas_in_sync(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) # All the devices on a given worker are in sync which in this case is the # number of gpus on each worker. - self.assertEqual(2, distribution.num_replicas_in_sync) + self.assertEqual(2, strategy.num_replicas_in_sync) - def testDeviceAssignmentLocalCPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=0) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalCPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=0, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + strategy, compute_device='CPU', variable_device='CPU', num_gpus=0) - def testDeviceAssignmentLocalOneGPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=1) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + strategy, compute_device='GPU', variable_device='GPU', num_gpus=1) - def testDeviceAssignmentLocalTwoGPUs(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + strategy, compute_device='GPU', variable_device='CPU', num_gpus=2) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributed(self, num_gpus): - self._test_device_assignment_distributed('worker', 1, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributed(self, num_gpus, use_core_strategy): + self._test_device_assignment_distributed( + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus, + use_core_strategy): self._test_device_assignment_distributed_enable_partitioner( - 'worker', 1, num_gpus) + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testLocalSimpleIncrement(self, num_gpus): - self._test_simple_increment(None, 0, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testLocalSimpleIncrement(self, num_gpus, use_core_strategy): + self._test_simple_increment(None, 0, num_gpus, use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphDistributed(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraphDistributed(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphLocal(self, num_gpus): - self._test_minimize_loss_graph(None, None, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): + self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) # TODO(priyag): Refactor this and other multi worker tests. @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorDistributed(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testMakeInputFnIteratorDistributed(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') dataset_fn = lambda: dataset_ops.Dataset.range(100) @@ -612,12 +713,21 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 - self._test_input_fn_iterator('worker', 1, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + 'worker', + 1, + num_gpus, + input_fn, + expected_values, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorLocal(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') dataset_fn = lambda: dataset_ops.Dataset.range(100) @@ -629,23 +739,31 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) # only one worker and pipeline for local. - self._test_input_fn_iterator(None, None, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + None, + None, + num_gpus, + input_fn, + expected_values, + use_core_strategy=use_core_strategy) - def testGlobalStepUpdate(self): - strategy = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepUpdate(self, use_core_strategy): + strategy, _, _ = create_test_objects(use_core_strategy=use_core_strategy) self._test_global_step_update(strategy) - def testUpdateConfigProtoMultiWorker(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - distribution.configure( + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoMultiWorker(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + strategy.configure( cluster_spec=self._cluster_spec, task_type='worker', task_id=1) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify device filters. self.assertEqual(['/job:worker/task:1', '/job:ps'], @@ -654,16 +772,48 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, # Verify isolate_session_state self.assertFalse(new_config.isolate_session_state) - def testUpdateConfigProtoLocal(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoLocal(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto() - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify isolate_session_state self.assertTrue(new_config.isolate_session_state) + def testAllReduceSum(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradient_tape(distribution) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -674,20 +824,31 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2, has_chief=True) cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0] - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraph(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) - def testGlobalStepIsWrappedOnTwoGPUs(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsWrappedOnTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): created_step = training_util.create_global_step() get_step = training_util.get_global_step() self.assertEqual(created_step, get_step, @@ -696,12 +857,14 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, id(get_step), get_step.__class__.__name__))) self.assertIs(values.AggregatingVariable, type(created_step)) self.assertIs(values.AggregatingVariable, type(get_step)) - self.assertIs(distribution, created_step.distribute_strategy) + self.assertIs(strategy, created_step.distribute_strategy) - def testGlobalStepIsNotWrappedOnOneGPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=1) - with ops.Graph().as_default(), distribution.scope(): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsNotWrappedOnOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): created_step = training_util.create_global_step() get_step = training_util.get_global_step() self.assertEqual(created_step, get_step, @@ -710,20 +873,36 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, id(get_step), get_step.__class__.__name__))) self.assertIs(resource_variable_ops.ResourceVariable, type(created_step)) self.assertIs(resource_variable_ops.ResourceVariable, type(get_step)) - self.assertIs(distribution, created_step.distribute_strategy) + self.assertIs(strategy, created_step.distribute_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testValueContainer(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): - def testValueContainer(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): def f(): with backprop.GradientTape() as tape: v = variable_scope.get_variable('v', initializer=10.0) _ = v * v v, = tape.watched_variables() - w = distribution.extended.value_container(v) + w = strategy.extended.value_container(v) self.assertIs(values.AggregatingVariable, type(w)) - distribution.extended.call_for_each_replica(f) + + strategy.extended.call_for_each_replica(f) + + +class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine(mode=['graph', 'eager'], + use_core_strategy=[True, False], + required_gpus=2)) + def testNumpyIterator(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + self._test_numpy_iterator(strategy) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 1ff9b9ceec13351b098d47ed3ff62f689a625a31..a77d6d0bec85d321666f9c4043d22fc6a0c37158 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -45,7 +45,6 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): single_loss_step, layer = single_loss_example( optimizer_fn, distribution, use_bias=True, iterations_per_step=2) - self.evaluate(distribution.initialize()) if context.executing_eagerly(): run_step = single_loss_step else: @@ -57,12 +56,9 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 6e5280e35632d3f3cb6a4fe172a15fb7f508354c..7455cbd02a2571ed3bae81864580440c982e5f7b 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values @@ -31,6 +34,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -292,3 +296,190 @@ class DistributionTestBase(test.TestCase): global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) self.assertEqual((1,) * len(global_step_tensors), global_step_values) + + def _test_numpy_iterator(self, strategy): + with strategy.scope(), self.cached_session() as sess: + x = np.asarray([[1, 2], [6, 12], [2, 4], + [5, 10], [3, 6], [4, 8]]) + y = np.asarray([5, 4, 3, 2, 1, 0]) + batch_size = 6 + if not strategy.extended._global_batch_size: # pylint: disable=protected-access + batch_size = batch_size // strategy.num_replicas_in_sync + i = strategy.experimental_make_numpy_iterator( + (x, y), batch_size=batch_size, num_epochs=2, shuffle=None, + session=sess) + self.evaluate(i.initialize()) + + def run_and_concatenate(strategy, i): + x, y = strategy.experimental_run(lambda z: z, i) + x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y))) + return np.concatenate(x), np.concatenate(y) + + x_1, y_1 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_1) + self.assertAllEqual(y, y_1) + x_2, y_2 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_2) + self.assertAllEqual(y, y_2) + with self.assertRaises(errors.OutOfRangeError): + run_and_concatenate(strategy, i) + + +class OneDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any one-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0]], outputs[0]) + self.assertAllEqual([expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) + y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) + with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +class TwoDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any two-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(2., [21., 21.5])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0], expected[0]], outputs[0]) + self.assertAllEqual([expected[1], expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) + y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) + with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +def _all_sum(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.SUM, value) + + +def _all_mean(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 9e465f30c1b14dd1602919442aaffc5991ddeaf4..518c704b8990fbadab6a9707dd10611fd878e8ce 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -33,6 +33,8 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib @@ -50,6 +52,26 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +def initialize_tpu_system(cluster_resolver=None): + """Initialize the TPU devices in a separate session and graph. + + Args: + cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + """ + if cluster_resolver is None: + cluster_resolver = resolver_lib.TPUClusterResolver("") + master = cluster_resolver.master() + + logging.info("Initializing the TPU system.") + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + sess.run([tpu.initialize_system()]) + logging.info("Finished initializing TPU system.") + + def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -126,7 +148,8 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def __init__(self, tpu_cluster_resolver=None, - steps_per_run=None): + steps_per_run=None, + **kwargs): """Initializes the TPUStrategy object. Args: @@ -137,10 +160,22 @@ class TPUStrategy(distribute_lib.DistributionStrategy): metrics, summaries etc. This parameter is only used when Distribution Strategy is used with estimator or keras. + **kwargs: Additional experimental flags. Will be removed in future. """ super(TPUStrategy, self).__init__(TPUExtended( self, tpu_cluster_resolver, steps_per_run)) + self._disable_training_loop_on_host = False + if len(kwargs) > 1: + raise ValueError("TPUStrategy constructor only takes one experimental " + "flag now") + if len(kwargs) == 1: + if "_disable_training_loop_on_host" not in kwargs: + raise ValueError("TPUStrategy constructor does not support arguments: " + "{}".format(kwargs)) + self._disable_training_loop_on_host = ( + kwargs["_disable_training_loop_on_host"]) + @property def steps_per_run(self): """DEPRECATED: use .extended.steps_per_run instead.""" @@ -150,11 +185,6 @@ class TPUStrategy(distribute_lib.DistributionStrategy): class TPUExtended(distribute_lib.DistributionStrategyExtended): """Implementation of TPUStrategy.""" - # Track what TPU devices have been initialized. This is *intentionally* - # shared across all instances of TPUExtended as we want to keep track of which - # devices are initialized globally. - _initialized_devices = [] - def __init__(self, container_strategy, tpu_cluster_resolver=None, @@ -191,39 +221,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): (self.get_host(hid), [self.get_host_cpu_device(hid)]) for hid in range(self.num_hosts) ] - self._input_workers = values.InputWorkers(input_device_map, worker_devices) + self._input_workers = input_lib.InputWorkers( + input_device_map, worker_devices) # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run self._require_static_shapes = True - # Initialize the TPU devices. - self._initialize_tpu() - - def _initialize_tpu(self): - """Initialize the TPU devices in a separate session and graph. - - We keep track of all the TPU devices that we're initialized as we should - only be running TPU initialize once for the entire process. - """ - master = self._tpu_cluster_resolver.master() - # Verify TPU has not already been initialized in this process. - if master in TPUExtended._initialized_devices: - logging.info("TPU master %s has already been initialized." % master) - return - - logging.info("Initializing the TPU system.") - session_config = config_pb2.ConfigProto(allow_soft_placement=True) - self._configure(session_config) - with ops.Graph().as_default(): - with session_lib.Session(config=session_config, target=master) as sess: - sess.run([tpu.initialize_system()]) - logging.info("Finized initializing TPU system.") - - # Update Strategy state to make sure we can track device initialization. - TPUExtended._initialized_devices.append(master) - def _validate_colocate_with_variable(self, colocate_with_variable): values.validate_colocate_tpu_variable(colocate_with_variable, self) @@ -290,15 +295,19 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterators for each of the TPU hosts.""" - - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) def _distribute_dataset(self, dataset_fn): - return values.MultiWorkerDataset( + return input_lib.MultiWorkerDataset( functools.partial(self._call_dataset_fn, dataset_fn), self._input_workers) + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, numpy_dataset.SingleDevice(self.get_host_cpu_device(0)), + session) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. @@ -326,10 +335,11 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() - def run_fn(): + def run_fn(*args, **kwargs): """Single step on the TPU device.""" + del args, kwargs fn_result = fn(ctx, dequeue_fn()) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: @@ -338,6 +348,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): else: return fn_result + def iterate_on_tpu(): + return training_loop.repeat(iterations, run_fn, initial_loop_values) + # We capture the control_flow_context at this point, before we run `fn` # inside a while_loop and TPU replicate context. This is useful in cases # where we might need to exit these contexts and get back to the outer @@ -347,56 +360,77 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): self._outer_control_flow_context = ( ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - def rewrite_fn(*args): - """The rewritten step fn running on TPU.""" - del args + # pylint: disable=protected-access + if self._container_strategy()._disable_training_loop_on_host: replicate_inputs = [[]] * self._num_replicas_in_sync - replicate_outputs = tpu.replicate(run_fn, replicate_inputs) - - # If run_fn has tensor outputs, tpu.replicate returns a list of list. We - # will flatten it in this case. If run_fn has no tensor outputs, - # tpu.replicate returns a list of no_ops, we will keep the output as it - # is. - if isinstance(replicate_outputs[0], list): - replicate_outputs = nest.flatten(replicate_outputs) - - return replicate_outputs - - # TODO(sourabhbajaj): The input to while loop should be based on the output - # type of the step_fn - assert isinstance(initial_loop_values, list) - initial_loop_values = initial_loop_values * self._num_replicas_in_sync - - # Put the while loop op on host 0. - with ops.device(self.get_host_cpu_device(0)): - replicate_outputs = training_loop.repeat(iterations, rewrite_fn, - initial_loop_values) + replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + else: + def rewrite_fn(*args): + """The rewritten step fn running on TPU.""" + del args + replicate_inputs = [[]] * self._num_replicas_in_sync + replicate_outputs = tpu.replicate(run_fn, replicate_inputs) + + # If run_fn has tensor outputs, tpu.replicate returns a list of list. We + # will flatten it in this case. If run_fn has no tensor outputs, + # tpu.replicate returns a list of no_ops, we will keep the output as it + # is. + if isinstance(replicate_outputs[0], list): + replicate_outputs = nest.flatten(replicate_outputs) + + return replicate_outputs + + # TODO(sourabhbajaj): The input to while loop should be based on the + # output type of the step_fn + assert isinstance(initial_loop_values, list) + initial_loop_values = initial_loop_values * self._num_replicas_in_sync + + # Put the while loop op on host 0. + with ops.device(self.get_host_cpu_device(0)): + replicate_outputs = training_loop.repeat(iterations, rewrite_fn, + initial_loop_values) del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) - if isinstance(replicate_outputs, list): + if self._container_strategy()._disable_training_loop_on_host: # Filter out any ops from the outputs, typically this would be the case # when there were no tensor outputs. - last_step_tensor_outputs = [ - x for x in replicate_outputs if not isinstance(x, ops.Operation) - ] - - # Outputs are currently of the structure (flattened) - # [output0_device0, output1_device0, output2_device0, - # output0_device1, output1_device1, output2_device1, - # ...] + last_step_tensor_outputs = [x for x in replicate_outputs + if not isinstance(x, ops.Operation)] + + # Outputs are currently of the structure (grouped by device) + # [[output0_device0, output1_device0, output2_device0], + # [output0_device1, output1_device1, output2_device1]] # Convert this to the following structure instead: (grouped by output) # [[output0_device0, output0_device1], # [output1_device0, output1_device1], # [output2_device0, output2_device1]] - output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync - last_step_tensor_outputs = [ - last_step_tensor_outputs[i::output_num] for i in range(output_num) - ] + last_step_tensor_outputs = [list(x) for x in + zip(*last_step_tensor_outputs)] else: - # no tensors returned. - last_step_tensor_outputs = [] + if isinstance(replicate_outputs, list): + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [ + x for x in replicate_outputs if not isinstance(x, ops.Operation) + ] + + # Outputs are currently of the structure (flattened) + # [output0_device0, output1_device0, output2_device0, + # output0_device1, output1_device1, output2_device1, + # ...] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync + last_step_tensor_outputs = [ + last_step_tensor_outputs[i::output_num] for i in range(output_num) + ] + else: + # no tensors returned. + last_step_tensor_outputs = [] # Convert replicate_outputs to the original dict structure of # last_step_outputs. @@ -423,19 +457,13 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def _initialize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - return [] + def _experimental_initialize_system(self): + """Experimental method added to be used by Estimator. - def _finalize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - return [] + This is a private method only to be used by Estimator. Other frameworks + should directly be calling `tf.contrib.distribute.initialize_tpu_system` + """ + initialize_tpu_system(self._tpu_cluster_resolver) def _create_variable(self, next_creator, *args, **kwargs): """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" @@ -443,6 +471,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if colocate_with is None: device_map = self._device_map logical_device = 0 # TODO(josh11b): Get logical device from scope here. + elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) else: device_map = colocate_with.device_map logical_device = colocate_with.logical_device @@ -635,6 +666,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. + + `distribute_dataset` and `make_input_fn_iterator` assume per-replica + batching. + + Returns: + Boolean. + """ return True diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 73efb524b93a367d98395d4e83ac4bf136318a27..51c58b0b2f3dc2ab63e22718825a471b8657f892 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -22,28 +22,20 @@ import os from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.util import nest class DistributedValuesTest(test.TestCase): @@ -354,444 +346,6 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -class PerReplicaDatasetTest(test.TestCase): - - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _test_iterator(self, devices, dataset, expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map) - per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) - if context.executing_eagerly(): - iterator = per_replica_dataset.make_one_shot_iterator() - else: - iterator = per_replica_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next_as_list() - computed_value = self.evaluate(next_element) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next_as_list() - self.evaluate(next_element) - - @test_util.run_in_graph_and_eager_modes - def testOneDevice(self): - devices = ["/device:CPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleDevices(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testTupleDataset(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnevenDatasetBatches(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(devices, dataset, expected_values) - - def testInitializableIterator(self): - with context.graph_mode(): - devices = ["/device:CPU:0"] - # Using random input since that is only allowed with initializable - # iterator. - dataset = dataset_ops.Dataset.from_tensor_slices( - random_ops.random_uniform((10,))) - - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map) - per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) - iterator = per_replica_dataset.make_initializable_iterator() - - self.evaluate(iterator.initializer) - next_element = iterator.get_next_as_list() - for _ in range(10): - self.evaluate(next_element) - - # Should fail after the input is finished. - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(next_element) - - # After re-initializing the iterator, should be able to iterate again. - self.evaluate(iterator.initializer) - for _ in range(10): - self.evaluate(next_element) - - -class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - - def _test_iterator(self, sess, iterator, devices, expected_values): - next_element = iterator.get_next() - for r, device in enumerate(devices): - v = values.select_replica(r, next_element) - # The `v` here can be a tuple. - for element in nest.flatten(v): - self.assertTrue(element.device in device) - - for expected_value in expected_values: - t = [values.select_replica(r, next_element) for r in range(len(devices))] - actual = sess.run(t) - self.assertEqual(expected_value, actual) - - with self.assertRaises(errors.OutOfRangeError): - sess.run([values.select_replica(r, next_element) - for r in range(len(devices))]) - - def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_devices) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.cached_session() as sess: - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, expected_values) - - def _cpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"]) - ) - devices = [ - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def _cpu_and_one_gpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", ( - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - )), - ("/job:worker/replica:0/task:1", ( - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - )) - ) - devices = [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def testDataDistributionOneDevicePerWorker(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testDataDistributionTwoDevicePerWorker(self): - if context.num_gpus() < 1: - self.skipTest("A GPU is not available for this test.") - worker_devices, devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) - - def testTupleDataset(self): - worker_devices, devices = self._cpu_devices() - - with context.graph_mode(): - - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(8) - dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] - self._test_dataset(dataset_fn, worker_devices, devices, - expected_values) - - def testInitializableIterator(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(8) - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_devices) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - # After re-initializing the iterator, should be able to iterate again. - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testValueErrorForIterator(self): - # Incompatiable arguments. - d1 = "/device:GPU:0" - d2 = "/device:GPU:1" - device_map = values.ReplicaDeviceMap([d1, d2]) - input_workers = values.InputWorkers( - device_map, (("w1", (d1,)), ("w2", (d2,)))) - with self.assertRaises(ValueError): - values.MultiWorkerDataIterator([("w1", None)], input_workers) - - def testDuplicateDevices(self): - _, devices = self._cpu_devices() - devices.append("/job:worker/replica:0/task:0/device:CPU:0") - with self.assertRaises(ValueError): - _ = values.ReplicaDeviceMap(devices) - - -class InputIteratorTestBase(test.TestCase): - - def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, split_batch_by=None): - devices = nest.flatten([ds for _, ds in worker_device_pairs]) - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_device_pairs) - - if input_type == "input_fn": - input_contexts = [ - distribute_lib.InputContext() for _ in worker_device_pairs] - input_fn = lambda _: dataset_fn() - iterator = values.InputFunctionIterator( - input_fn, input_workers, input_contexts) - else: - iterator = values.DatasetIterator( - dataset_fn(), input_workers, split_batch_by) - - evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - evaluate([values.select_replica(r, next_element) - for r in range(len(devices))]) - - # After re-initializing the iterator, should be able to iterate again. - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - -class InputIteratorSingleWorkerTest(InputIteratorTestBase, - parameterized.TestCase): - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"])) - def testOneDeviceCPU(self, input_type): - worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesOneGPUOneCPU(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTupleDataset(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testUnevenDatasetBatches(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["dataset"], - split_batch_by=[None, 2], - required_gpus=1)) - def testBatchSplitting(self, input_type, split_batch_by): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - batch_size = 10 - dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) - - updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) - expected_values = [[range(i, i+updated_batch_size), - range(i+updated_batch_size, i+2*updated_batch_size)] - for i in range(0, 100, updated_batch_size*2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, - split_batch_by=split_batch_by) - - -class InputIteratorMultiWorkerTest( - multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, - parameterized.TestCase): - - def _cpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - - def _cpu_and_one_gpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testOneDevicePerWorker(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesPerWorker(self, input_type): - worker_devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 1, 0, 1], [2, 3, 2, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testTupleDataset(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) - - -class SplitDatasetBatchTest(test.TestCase): - - def testBatchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testMapAndBatchDataset(self): - dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testPrefetchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 15776c694e92825895437a4c1547699f6d9269fb..9b5a2c947b153308c83f1a922d06c034ec5f9ddf 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -128,7 +128,7 @@ class PTBModel(tf.keras.Model): self.linear = layers.Dense( vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) - self._output_shape = [-1, embedding_dim] + self._output_shape = [-1, hidden_dim] def call(self, input_seq, training): """Run the forward pass of PTBModel. diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 4c1d1a29f20b5574b63cf87ecf62db95f92902cd..4e29e2559986012d8eeeaec807f14181226363aa 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -6,7 +6,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "feature_column_py", @@ -37,13 +37,13 @@ py_library( ], ) -py_test( +tf_py_test( name = "sequence_feature_column_test", srcs = ["python/feature_column/sequence_feature_column_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -53,17 +53,14 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column:feature_column_py", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], + tags = ["no_pip"], ) -py_test( +tf_py_test( name = "sequence_feature_column_integration_test", srcs = ["python/feature_column/sequence_feature_column_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -73,6 +70,7 @@ py_test( "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras:layers", ], + tags = ["no_pip"], ) py_library( @@ -94,14 +92,14 @@ py_library( ], ) -py_test( +tf_py_test( name = "sequence_feature_column_v2_test", srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", ":sequence_feature_column_v2", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -112,7 +110,6 @@ py_test( "//tensorflow/python:training", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/feature_column:feature_column_v2_test", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], + tags = ["no_pip"], ) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index dad50a3a73085526f65bd87c3d8549ceb75b3af4..88a14a2a94cc683f021d032ea11358e0cfb63faa 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -50,6 +50,7 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_estimator:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index ae8320cfb279ac3a5cf7cb17ea2d82da79dbe432..db0868fb2c43464a811b3d6dfcd96480ba2463ee 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -1,8 +1,7 @@ -# Files for using TFGAN framework. +# Files for using TF-GAN framework. load("//tensorflow:tensorflow.bzl", "py_test") package(default_visibility = [ - "//learning/brain/contrib/tfgan:__subpackages__", "//tensorflow:__subpackages__", ]) @@ -107,6 +106,7 @@ py_library( deps = [ ":gan_estimator", ":head", + ":latent_gan_estimator", ":stargan_estimator", ":tpu_gan_estimator", "//tensorflow/python:util", @@ -132,6 +132,7 @@ py_library( ":clip_weights", ":conditioning_utils", ":random_tensor_pool", + ":spectral_normalization", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -145,16 +146,15 @@ py_library( "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", + "//tensorflow/python:gradients_impl", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:summary", "//tensorflow/python:tensor_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/losses", - "//third_party/py/numpy", ], ) @@ -570,22 +570,18 @@ py_test( deps = [ ":namedtuples", ":stargan_estimator", - ":tuple_losses", "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/learn", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", "@six_archive//:six", @@ -647,6 +643,41 @@ py_test( ], ) +py_library( + name = "latent_gan_estimator", + srcs = [ + "python/estimator/python/latent_gan_estimator.py", + "python/estimator/python/latent_gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":train", + "//tensorflow/python:clip_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:random_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "latent_gan_estimator_test", + srcs = [ + "python/estimator/python/latent_gan_estimator_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":latent_gan_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/ops/losses", + ], +) + py_library( name = "sliced_wasserstein", srcs = [ @@ -681,3 +712,45 @@ py_test( "//third_party/py/numpy", ], ) + +py_library( + name = "spectral_normalization", + srcs = [ + "python/features/python/spectral_normalization.py", + "python/features/python/spectral_normalization_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:standard_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/keras:engine", + ], +) + +py_test( + name = "spectral_normalization_test", + srcs = ["python/features/python/spectral_normalization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":spectral_normalization", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/slim", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/keras:layers", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 9ab86329eaf0e6fd426aef1f552f4e27c2ad65de..db7dc51daa78ecee12ecb7f6d33df4511e068243 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -1,14 +1,15 @@ -# TensorFlow-GAN (TFGAN) + +# TensorFlow-GAN (TF-GAN) -TFGAN is a lightweight library for training and evaluating Generative +TF-GAN is a lightweight library for training and evaluating Generative Adversarial Networks (GANs). This technique allows you to train a network (called the 'generator') to sample from a distribution, without having to explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an +Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](http://https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction. #### Usage @@ -17,27 +18,27 @@ import tensorflow as tf tfgan = tf.contrib.gan ``` -## Why TFGAN? +## Why TF-GAN? * Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). You can -mix TFGAN, native TF, and other custom frameworks +mix TF-GAN, native TF, and other custom frameworks * Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) * [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them * Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training * Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) -* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model -* Improvements in TFGAN infrastructure will automatically benefit your TFGAN project +* Use the TF-GAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model +* Improvements in TF-GAN infrastructure will automatically benefit your TF-GAN project * Stay up-to-date with research as we add more algorithms -## What are the TFGAN components? +## What are the TF-GAN components? -TFGAN is composed of several parts which were design to exist independently. +TF-GAN is composed of several parts which were design to exist independently. These include the following main pieces (explained in detail below). * [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py): provides the main infrastructure needed to train a GAN. Training occurs in four phases, and each phase can be completed by custom-code or by using a - TFGAN library call. + TF-GAN library call. * [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/): Many common GAN operations and normalization techniques are implemented for @@ -56,14 +57,15 @@ These include the following main pieces (explained in detail below). generative models. * [examples](https://github.com/tensorflow/models/tree/master/research/gan/) - and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make - GAN training easier, or use the more complicated examples to jumpstart your - own project. These include unconditional and conditional GANs, InfoGANs, - adversarial losses on existing networks, and image-to-image translation. + and [tutorial](http://https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN + to make GAN training easier, or use the more complicated examples to + jumpstart your own project. These include unconditional and conditional + GANs, InfoGANs, adversarial losses on existing networks, and image-to-image + translation. ## Training a GAN model -Training in TFGAN typically consists of the following steps: +Training in TF-GAN typically consists of the following steps: 1. Specify the input to your networks. 1. Set up your generator and discriminator using a `GANModel`. @@ -71,12 +73,12 @@ Training in TFGAN typically consists of the following steps: 1. Create your train ops using a `GANTrainOps`. 1. Run your train ops. -At each stage, you can either use TFGAN's convenience functions, or you can +At each stage, you can either use TF-GAN's convenience functions, or you can perform the step manually for fine-grained control. We provide examples below. There are various types of GAN setups. For instance, you can train a generator to sample unconditionally from a learned distribution, or you can condition on -extra information such as a class label. TFGAN is compatible with many setups, +extra information such as a class label. TF-GAN is compatible with many setups, and we demonstrate a few below: ### Examples @@ -254,9 +256,9 @@ with variable_scope.variable_scope(dis_scope, reuse=True): discriminator_real_outputs = discriminator_fn(images) generator_variables = variables_lib.get_trainable_variables(gen_scope) discriminator_variables = variables_lib.get_trainable_variables(dis_scope) -# Depending on what TFGAN features you use, you don't always need to supply +# Depending on what TF-GAN features you use, you don't always need to supply # every `GANModel` field. At a minimum, you need to include the discriminator -# outputs and variables if you want to use TFGAN to construct losses. +# outputs and variables if you want to use TF-GAN to construct losses. gan_model = tfgan.GANModel( generator_inputs, generated_data, diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index f1946c7f925660eae3aaa650c437e03da1f33d6c..1e6000898f7b8a53ad3f6fa12deebd54bf3a57ff 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN is a lightweight library for training and evaluating GANs. +"""TF-GAN is a lightweight library for training and evaluating GANs. In addition to providing the infrastructure for easily training and evaluating GANS, this library contains modules for a TFGAN-backed Estimator, @@ -24,7 +24,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# Collapse TFGAN into a tiered namespace. +# Collapse TF-GAN into a tiered namespace. from tensorflow.contrib.gan.python import estimator from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin from tensorflow.contrib.gan.python import features diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 75cccb5ea004c11f250050fedf7475dec6d3b699..430266555b723e6ca39dccffc1442dbef5d4a385 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN estimator module. +"""TF-GAN estimator module. GANEstimator provides all the infrastructure support of a TensorFlow Estimator -with the feature support of TFGAN. +with the feature support of TF-GAN. """ from __future__ import absolute_import @@ -26,11 +26,13 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator from tensorflow.contrib.gan.python.estimator.python import head +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator from tensorflow.contrib.gan.python.estimator.python import stargan_estimator from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.head import * +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import * from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator import * # pylint: enable=unused-import,wildcard-import @@ -41,7 +43,8 @@ _allowed_symbols = ([ 'gan_estimator', 'stargan_estimator', 'tpu_gan_estimator', + 'latent_gan_estimator', 'head', ] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ + - tpu_gan_estimator.__all__) + tpu_gan_estimator.__all__ + latent_gan_estimator.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index adb72228217892fffc10b0e2630edcd9d3e38a02..dd904611d1a3bb78de8316d5ed29ab0f800f29a9 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division @@ -56,10 +56,10 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. The network functions follow the TFGAN API - except for one exception: if either `generator_fn` or `discriminator_fn` have - an argument called `mode`, then the tf.Estimator mode is passed in for that - argument. This helps with operations like batch normalization, which have + This Estimator is backed by TF-GAN. The network functions follow the TF-GAN + API except for one exception: if either `generator_fn` or `discriminator_fn` + have an argument called `mode`, then the tf.Estimator mode is passed in for + that argument. This helps with operations like batch normalization, which have different train and evaluation behavior. Example: @@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator): import tensorflow as tf tfgan = tf.contrib.gan - # See TFGAN's `train.py` for a description of the generator and + # See TF-GAN's `train.py` for a description of the generator and # discriminator API. def generator_fn(generator_inputs): ... @@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. Additionally, if + generator. See `TF-GAN` for more details and examples. Additionally, if it has an argument called `mode`, the Estimator's `mode` will be passed in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch normalization. discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. - Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details and examples. generator_loss_fn: The loss function on the generator. Takes a `GANModel` tuple. diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 5a3d29cf0b3cb1bbe03cb5ba4f327caf46432b76..5b9c54e43a16adf457d5ed0e7e73dcd168ab0d67 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's estimator.py.""" +"""Tests for TF-GAN's estimator.py.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 1a0ee6dfc498eb6dc8c97411589d9e35bc352062..cbe990b476c3b17ce61e0826b17d10976fea43c7 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8205bc889dc01c8680e2139393d65723280cfbd0..5b50234a0e33cd297b176f142b358338966b6758 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's head.py.""" +"""Tests for TF-GAN's head.py.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..4e164e24168bb0cc5e9a7cc772081781ea088bb1 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `Train Input Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = latent_gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..f5afc7731937ed1a82c8ebb5969b2687ffdd583b --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py @@ -0,0 +1,205 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements an estimator wrapper that allows training the input latent space. + +This file implements a latent gan estimator that wraps around a previously +trained GAN. The latent gan estimator trains a single variable z, representing +the hidden latent distribution that is the 'noise' input to the GAN. By training +z, the inpainting estimator can move around the latent z space towards +minimizing a specific loss function. + +The latent gan estimator has a few key differences from a normal estimator. + +First: the variables in the estimator should not be saved, as we are not +updating the original GAN and are only adding a new z variable that is meant +to be different for each run. In order to do distributed training using +train_and_evaluate, the Tensorflow RunConfig is expected to save checkpoints +by having either save_checkpoints_steps or save_checkpoints_secs saved. +To avoid this conflict, we purposely set the save_checkpoints_steps value in +the RunConfig to be one step more than the total number of steps that the +inpainter estimator will run. + +Second: we need to specify warm start settings, as we are reloading the +GAN model into a different graph (specifically, one with a new z variable). +The warm start settings defined below reload all GAN variables and ignore the +new z variable (and the optimizer). + +Usage: + + def _generator(net, mode): + ... + + def _discriminator(net, condition, mode): + ... + + def _loss(gan_model, features, labels, add_summaries): + ... + + def optimizer(): + ... + + params = {} + config = tf.estimator.RunConfig() + tmp_dir = path/to/output/storage + + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir) + + def input_fn(): + ... + + estimator.train(input_fn=input_fn) + +See latent_gan_estimator_test.py or tensorflow_models/gan/face_inpainting for +further examples. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + + +INPUT_NAME = 'new_var_z_input' # The name for the new z space input variable. +OPTIMIZER_NAME = 'latent_gan_optimizer' # The name for the new optimizer vars. + +__all__ = [ + 'get_latent_gan_estimator', +] + + +def _get_latent_gan_model_fn(generator_fn, discriminator_fn, loss_fn, + optimizer): + """Sets up a model function that wraps around a given GAN.""" + def model_fn(features, labels, mode, params): + """Model function defining an inpainting estimator.""" + batch_size = params['batch_size'] + z_shape = [batch_size] + params['z_shape'] + add_summaries = params['add_summaries'] + input_clip = params['input_clip'] + + z = variable_scope.get_variable( + name=INPUT_NAME, initializer=random_ops.truncated_normal(z_shape), + constraint=lambda x: clip_ops.clip_by_value(x, -input_clip, input_clip)) + + generator = functools.partial(generator_fn, mode=mode) + discriminator = functools.partial(discriminator_fn, mode=mode) + gan_model = tfgan_train.gan_model(generator_fn=generator, + discriminator_fn=discriminator, + real_data=labels, + generator_inputs=z, + check_shapes=False) + + loss = loss_fn(gan_model, features, labels, add_summaries) + + # Use a variable scope to make sure that estimator variables dont cause + # save/load problems when restoring from ckpts. + with variable_scope.variable_scope(OPTIMIZER_NAME): + opt = optimizer(learning_rate=params['learning_rate'], + **params['opt_kwargs']) + train_op = opt.minimize( + loss=loss, global_step=training_util.get_or_create_global_step(), + var_list=[z]) + + if add_summaries: + z_grads = gradients_impl.gradients(loss, z) + summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads)) + summary.scalar('z_loss/loss', loss) + + return model_fn_lib.EstimatorSpec(mode=mode, + predictions=gan_model.generated_data, + loss=loss, + train_op=train_op) + return model_fn + + +def get_latent_gan_estimator(generator_fn, discriminator_fn, loss_fn, + optimizer, params, config, ckpt_dir, + warmstart_options=True): + """Gets an estimator that passes gradients to the input. + + This function takes in a generator and adds a trainable z variable that is + used as input to this generator_fn. The generator itself is treated as a black + box through which gradients can pass through without updating any weights. The + result is a trainable way to traverse the GAN latent space. The loss_fn is + used to actually train the z variable. The generator_fn and discriminator_fn + should be previously trained by the tfgan library (on reload, the variables + are expected to follow the tfgan format. It may be possible to use the + latent gan estimator with entirely custom GANs that do not use the tfgan + library as long as the appropriate variables are wired properly). + + Args: + generator_fn: a function defining a Tensorflow graph for a GAN generator. + The weights defined in this graph should already be defined in the given + checkpoint location. Should have 'mode' as an argument. + discriminator_fn: a function defining a Tensorflow graph for a GAN + discriminator. Should have 'mode' as an argument. + loss_fn: a function defining a Tensorflow graph for a GAN loss. Takes in a + GANModel tuple, features, labels, and add_summaries as inputs. + optimizer: a tf.Optimizer or a function that returns a tf.Optimizer with no + inputs. + params: An object containing the following parameters: + - batch_size: an int indicating the size of the training batch. + - z_shape: the desired shape of the input z values (not counting batch). + - learning_rate: a scalar or function defining a learning rate applied to + optimizer. + - input_clip: the amount to clip the x training variable by. + - add_summaries: whether or not to add summaries. + - opt_kwargs: optimizer kwargs. + config: tf.RunConfig. Should point model to output dir and should indicate + whether to save checkpoints (to avoid saving checkpoints, set + save_checkpoints_steps to a number larger than the number of train steps). + The model_dir field in the RunConfig should point to a directory WITHOUT + any saved checkpoints. + ckpt_dir: the directory where the model checkpoints live. The checkpoint is + used to warm start the underlying GAN. This should NOT be the same as + config.model_dir. + warmstart_options: boolean, None, or a WarmStartSettings object. If set to + True, uses a default WarmStartSettings object. If set to False or None, + does not use warm start. If using a custom WarmStartSettings object, make + sure that new variables are properly accounted for when reloading the + underlying GAN. Defaults to True. + Returns: + An estimator spec defining a GAN input training estimator. + """ + model_fn = _get_latent_gan_model_fn(generator_fn, discriminator_fn, + loss_fn, optimizer) + + if isinstance(warmstart_options, estimator.WarmStartSettings): + ws = warmstart_options + elif warmstart_options: + # Default WarmStart loads all variable names except INPUT_NAME and + # OPTIMIZER_NAME. + var_regex = '^(?!.*(%s|%s).*)' % (INPUT_NAME, OPTIMIZER_NAME) + ws = estimator.WarmStartSettings(ckpt_to_initialize_from=ckpt_dir, + vars_to_warm_start=var_regex) + else: + ws = None + + if 'opt_kwargs' not in params: + params['opt_kwargs'] = {} + + return estimator.Estimator(model_fn=model_fn, config=config, params=params, + warm_start_from=ws) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac139e532e35f7aae6da0655103a7249fe3382d4 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for latent_gan_estimator. + +See g3.tp.tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator +from tensorflow.python.estimator import run_config as run_config +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class TrainInputEstimatorTest(test.TestCase): + + def test_get_input_training_estimator(self): + """Integration test to make sure the input_training_estimator works.""" + + # Create dummy test input tensors. + true_features = np.reshape(np.random.uniform(size=100), (10, 10)) + true_labels = np.reshape(np.random.uniform(size=100), (5, 20)) + expected_z_output = [[1, -1], [-1, 1]] + + # Fill out required parameters randomly, includes optimizer kwargs. + params = { + 'batch_size': 2, + 'z_shape': [2], + 'learning_rate': 1.0, + 'input_clip': 1.0, + 'add_summaries': False, + 'opt_kwargs': { + 'beta1': 0.1 + } + } + + input_z_shape = [params['batch_size']] + params['z_shape'] + + # Create dummy model functions that represent an underlying GANEstimator and + # the input training wrapper. Make sure that everything is wired up + # correctly in the internals of each dummy function. + def _generator(net, mode): + """The generator function will get the newly created z variable.""" + del mode + self.assertSequenceEqual(net.shape, input_z_shape) + gen_dummy_var = variable_scope.get_variable( + name='generator_dummy_variable', + initializer=array_ops.ones(input_z_shape)) + return net * gen_dummy_var + + def _discriminator(net, condition, mode): + """The discriminator function will get either the z variable or labels.""" + del condition, mode + try: + self.assertSequenceEqual(net.shape, true_labels.shape) + except AssertionError: + self.assertSequenceEqual(net.shape, input_z_shape) + return net + + def _loss(gan_model, features, labels, _): + """Make sure that features and labels are passed in from input.""" + self.assertTrue(np.array_equal(features, true_features)) + self.assertTrue(np.array_equal(labels, true_labels)) + return losses.absolute_difference(expected_z_output, + gan_model.generated_data) + + optimizer = training.AdamOptimizer + + # We are not loading checkpoints, so set the corresponding directory to a + # dummy directories. + tmp_dir = tempfile.mkdtemp() + config = run_config.RunConfig(model_dir=tmp_dir, + save_summary_steps=None, + save_checkpoints_steps=1, + save_checkpoints_secs=None) + + # Get the estimator. Disable warm start so that there is no attempted + # checkpoint reloading. + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir, + warmstart_options=None) + + # Train for a few steps. + def dummy_input(): + return true_features, true_labels + estimator.train(input_fn=dummy_input, steps=10) + + # Make sure the generator variables did not change, but the z variables did + # change. + self.assertTrue(np.array_equal( + estimator.get_variable_value('Generator/generator_dummy_variable'), + np.ones(input_z_shape))) + self.assertTrue(np.array_equal( + estimator.get_variable_value('new_var_z_input'), + expected_z_output)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py index f60e16bc04662b33bc0bb22b5acc8c7fcc7a03ba..2a485e7d47ff10cf34c1b44f8dcc6b1f33c9a05f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed StarGAN Estimator.""" +"""A TF-GAN-backed StarGAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py index 2ec7938c7c4051842c7e982b54c1213b6e841b79..c00ff4399748a77f88d9753df7592bf3859d754e 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's stargan_estimator.py.""" +"""Tests for TF-GAN's stargan_estimator.py.""" from __future__ import absolute_import from __future__ import division @@ -80,7 +80,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): self.assertEqual(input_data, gan_model.input_data) self.assertIsNotNone(gan_model.generated_data) self.assertIsNotNone(gan_model.generated_data_domain_target) - self.assertEqual(1, len(gan_model.generator_variables)) + self.assertLen(gan_model.generator_variables, 1) self.assertIsNotNone(gan_model.generator_scope) self.assertIsNotNone(gan_model.generator_fn) if mode == model_fn_lib.ModeKeys.PREDICT: @@ -109,7 +109,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): gan_model.discriminator_input_data_domain_predication) self.assertIsNotNone( gan_model.discriminator_generated_data_domain_predication) - self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.discriminator_scope) self.assertIsNotNone(gan_model.discriminator_fn) @@ -163,6 +163,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): + super(GetEstimatorSpecTest, cls).setUpClass() cls._generator_optimizer = training.GradientDescentOptimizer(1.0) cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py index 295d1382e2cefc6c6e016416dc1f4cddd4f3a46a..8f2a22c78a304c7cc66ef069a235483e9279b3b2 100644 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator that works on TPU.""" +"""A TF-GAN-backed GAN Estimator that works on TPU.""" from __future__ import absolute_import from __future__ import division @@ -294,9 +294,10 @@ def _get_estimator_spec( gan_model, gan_loss, gan_loss_no_reduction, get_eval_metric_ops_fn) else: # model_fn_lib.ModeKeys.TRAIN: gan_loss = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn(gan_model, add_summaries=False), + generator_loss=generator_loss_fn( + gan_model, add_summaries=not is_on_tpu), discriminator_loss=discriminator_loss_fn( - gan_model, add_summaries=False)) + gan_model, add_summaries=not is_on_tpu)) # Construct optimizers if arguments were callable. For TPUs, they must be # `CrossShardOptimizer`. diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py index 0a08b4386f88d0480effae63f511cf8140dd22f8..9fdcc08334d50b4ddf3a0bc9bc755e55d51b0bd8 100644 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's TPU Estimator.""" +"""Tests for TF-GAN's TPU Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index f86b8513053a45f9830411f7df2c32d1f36a97b2..92e9abf8a35de1999eb800e169f32220fe47f8cd 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN evaluation module. +"""TF-GAN evaluation module. This module supports techniques such as Inception Score, Frechet Inception distance, and Sliced Wasserstein distance. diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py index 1c872626a957279132772ae27df7a66a2564e9a5..a52e899114b62cb29752f72aa59f142f4a428aa1 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index ea55241b34314eaa64993a7a3855b9cb2f5922b9..31f0d34ed68a6adc25cca102236079d0f66615cb 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN. +"""Model evaluation tools for TF-GAN. These methods come from https://arxiv.org/abs/1606.03498, https://arxiv.org/abs/1706.08500, and https://arxiv.org/abs/1801.01401. @@ -795,9 +795,9 @@ def kernel_classifier_distance(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. dtype: if not None, coerce activations to this dtype before computations. @@ -872,9 +872,9 @@ def kernel_classifier_distance_and_std(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. Having a smaller block size also gives a better estimate of the standard error. dtype: if not None, coerce activations to this dtype before computations. @@ -911,7 +911,7 @@ def kernel_classifier_distance_and_std(real_images, gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) return kernel_classifier_distance_and_std_from_activations( - real_a, gen_a, max_block_size=max_block_size) + real_a, gen_a, max_block_size, dtype) kernel_inception_distance_and_std = functools.partial( @@ -968,14 +968,14 @@ def kernel_classifier_distance_from_activations(real_activations, into blocks for computational efficiency. Larger values are more computationally expensive but decrease the variance of the distance estimate. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. A floating-point scalar of the same type as the output of the activations. """ return kernel_classifier_distance_and_std_from_activations( - real_activations, generated_activations, max_block_size=max_block_size)[0] + real_activations, generated_activations, max_block_size, dtype)[0] def kernel_classifier_distance_and_std_from_activations(real_activations, @@ -1030,7 +1030,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, computationally expensive but decrease the variance of the distance estimate. Having a smaller block size also gives a better estimate of the standard error. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. A floating-point scalar of the same type @@ -1081,7 +1081,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, dim = math_ops.cast(real_activations.shape[1], dtype) def compute_kid_block(i): - 'Compute the ith block of the KID estimate.' + """Computes the ith block of the KID estimate.""" r_s = inds_r[i] r_e = inds_r[i + 1] r = real_activations[r_s:r_e] diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index dbff1d2a367e10adc607dafb4c571bb3607a3963..bd17571a0535a3c8e9dfee24a8da16eb2e72f165 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN classifier_metrics.""" +"""Tests for TF-GAN classifier_metrics.""" from __future__ import absolute_import from __future__ import division @@ -234,7 +234,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): else: logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) - self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) # Check that none of the model variables are trainable. @@ -258,7 +258,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): img, _get_dummy_graphdef(), output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(pool, ops.Tensor) pool.shape.assert_is_compatible_with([batch_size, 2048]) # Check that none of the model variables are trainable. @@ -276,8 +276,8 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.INCEPTION_FINAL_POOL ]) - self.assertTrue(isinstance(logits, ops.Tensor)) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) + self.assertIsInstance(pool, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -290,7 +290,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.inception_score, array_ops.zeros([6, 299, 299, 3]), num_batches=3) - self.assertTrue(isinstance(score, ops.Tensor)) + self.assertIsInstance(score, ops.Tensor) score.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -302,7 +302,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock( classifier_metrics.frechet_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -314,7 +314,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock(classifier_metrics.kernel_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py index 523968bed91f1021ae629bf52c405cf5c2d7b917..326fcb3cdbf2eda66207f134cd2926f09a216a99 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries.py b/tensorflow/contrib/gan/python/eval/python/summaries.py index ecfdb39499b1e824e02415c0db1de3157e4f3216..1b202dfc97304ddc7ced42d65366aaf419439392 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index f9995bb19d0d09eaf6fd96d039b0bba1d3a7055c..9f448d3a1602c503093214201bdc96fc9bee85b5 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 54a6f8d4d9086ad7fc8db31032677628561e48e8..53fc7cb8ede698c2d8590c7fd3016a884cef9be9 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN summaries.""" +"""Tests for TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 4816daf760143af9f1502873b123ffad8e5ec8ce..410c3a02052cd3a07a36a0ba332a80b3c2705d89 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -27,11 +27,13 @@ from __future__ import print_function from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils from tensorflow.contrib.gan.python.features.python import random_tensor_pool +from tensorflow.contrib.gan.python.features.python import spectral_normalization from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * +from tensorflow.contrib.gan.python.features.python.spectral_normalization import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -40,5 +42,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ _allowed_symbols += random_tensor_pool.__all__ +_allowed_symbols += spectral_normalization.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..54d3d0a218dec3588844333cd47e1f92489d8df9 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.spectral_normalization_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = spectral_normalization_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc653f0a7907f407e66add5537d1e0a5adb6d8b --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py @@ -0,0 +1,315 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import numbers +import re + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging + +__all__ = [ + 'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer', + 'spectral_normalization_custom_getter', 'keras_spectral_normalization' +] + +# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which then +# can't directly be assigned back to the tf.bfloat16 variable. +_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64) +_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u' + + +def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None): + """Estimates the largest singular value in the weight tensor. + + Args: + w_tensor: The weight matrix whose spectral norm should be computed. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + name: An optional scope name. + + Returns: + The largest singular value (the spectral norm) of w. + """ + with variable_scope.variable_scope(name, 'spectral_norm'): + # The paper says to flatten convnet kernel weights from + # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D + # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to + # (KH * KW * C_in, C_out), and similarly for other layers that put output + # channels as last dimension. + # n.b. this means that w here is equivalent to w.T in the paper. + w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1])) + + # Persisted approximation of first left singular vector of matrix `w`. + u_var = variable_scope.get_variable( + _PERSISTED_U_VARIABLE_SUFFIX, + shape=(w.shape[0], 1), + dtype=w.dtype, + initializer=init_ops.random_normal_initializer(), + trainable=False) + u = u_var + + # Use power iteration method to approximate spectral norm. + for _ in range(power_iteration_rounds): + # `v` approximates the first right singular vector of matrix `w`. + v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u)) + u = nn.l2_normalize(math_ops.matmul(w, v)) + + # Update persisted approximation. + with ops.control_dependencies([u_var.assign(u, name='update_u')]): + u = array_ops.identity(u) + + u = array_ops.stop_gradient(u) + v = array_ops.stop_gradient(v) + + # Largest singular value of `w`. + spectral_norm = math_ops.matmul( + math_ops.matmul(array_ops.transpose(u), w), v) + spectral_norm.shape.assert_is_fully_defined() + spectral_norm.shape.assert_is_compatible_with([1, 1]) + + return spectral_norm[0][0] + + +def spectral_normalize(w, power_iteration_rounds=1, name=None): + """Normalizes a weight matrix by its spectral norm. + + Args: + w: The weight matrix to be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + name: An optional scope name. + + Returns: + A normalized weight matrix tensor. + """ + with variable_scope.variable_scope(name, 'spectral_normalize'): + w_normalized = w / compute_spectral_norm( + w, power_iteration_rounds=power_iteration_rounds) + return array_ops.reshape(w_normalized, w.get_shape()) + + +def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None): + """Returns a functions that can be used to apply spectral norm regularization. + + Small spectral norms enforce a small Lipschitz constant, which is necessary + for Wasserstein GANs. + + Args: + scale: A scalar multiplier. 0.0 disables the regularizer. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + scope: An optional scope name. + + Returns: + A function with the signature `sn(weights)` that applies spectral norm + regularization. + + Raises: + ValueError: If scale is negative or if scale is not a float. + """ + if isinstance(scale, numbers.Integral): + raise ValueError('scale cannot be an integer: %s' % scale) + if isinstance(scale, numbers.Real): + if scale < 0.0: + raise ValueError( + 'Setting a scale less than 0 on a regularizer: %g' % scale) + if scale == 0.0: + logging.info('Scale of 0 disables regularizer.') + return lambda _: None + + def sn(weights, name=None): + """Applies spectral norm regularization to weights.""" + with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name: + scale_t = ops.convert_to_tensor( + scale, dtype=weights.dtype.base_dtype, name='scale') + return math_ops.multiply( + scale_t, + compute_spectral_norm( + weights, power_iteration_rounds=power_iteration_rounds), + name=name) + + return sn + + +def _default_name_filter(name): + """A filter function to identify common names of weight variables. + + Args: + name: The variable name. + + Returns: + Whether `name` is a standard name for a weight/kernel variables used in the + Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries. + """ + match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name) + return match is not None + + +def spectral_normalization_custom_getter(name_filter=_default_name_filter, + power_iteration_rounds=1): + """Custom getter that performs Spectral Normalization on a weight tensor. + + Specifically it divides the weight tensor by its largest singular value. This + is intended to stabilize GAN training, by making the discriminator satisfy a + local 1-Lipschitz constraint. + + Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]. + + [sn-gan]: https://openreview.net/forum?id=B1QRgziT- + + To reproduce an SN-GAN, apply this custom_getter to every weight tensor of + your discriminator. The last dimension of the weight tensor must be the number + of output channels. + + Apply this to layers by supplying this as the `custom_getter` of a + `tf.variable_scope`. For example: + + with tf.variable_scope('discriminator', + custom_getter=spectral_norm_getter()): + net = discriminator_fn(net) + + IMPORTANT: Keras does not respect the custom_getter supplied by the + VariableScope, so Keras users should use `keras_spectral_normalization` + instead of (or in addition to) this approach. + + It is important to carefully select to which weights you want to apply + Spectral Normalization. In general you want to normalize the kernels of + convolution and dense layers, but you do not want to normalize biases. You + also want to avoid normalizing batch normalization (and similar) variables, + but in general such layers play poorly with Spectral Normalization, since the + gamma can cancel out the normalization in other layers. By default we supply a + filter that matches the kernel variable names of the dense and convolution + layers of the tf.layers, tf.contrib.layers, tf.keras and tf.contrib.slim + libraries. If you are using anything else you'll need a custom `name_filter`. + + This custom getter internally creates a variable used to compute the spectral + norm by power iteration. It will update every time the variable is accessed, + which means the normalized discriminator weights may change slightly whilst + training the generator. Whilst unusual, this matches how the paper's authors + implement it, and in general additional rounds of power iteration can't hurt. + + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yeilds a better approximation of the + true spectral norm. + + Returns: + A custom getter function that applies Spectral Normalization to all + Variables whose names match `name_filter`. + + Raises: + ValueError: If name_filter is not callable. + """ + if not callable(name_filter): + raise ValueError('name_filter must be callable') + + def _internal_getter(getter, name, *args, **kwargs): + """A custom getter function that applies Spectral Normalization. + + Args: + getter: The true getter to call. + name: Name of new/existing variable, in the same format as + tf.get_variable. + *args: Other positional arguments, in the same format as tf.get_variable. + **kwargs: Keyword arguments, in the same format as tf.get_variable. + + Returns: + The return value of `getter(name, *args, **kwargs)`, spectrally + normalized. + + Raises: + ValueError: If used incorrectly, or if `dtype` is not supported. + """ + if not name_filter(name): + return getter(name, *args, **kwargs) + + if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX): + raise ValueError( + 'Cannot apply Spectral Normalization to internal variables created ' + 'for Spectral Normalization. Tried to normalized variable [%s]' % + name) + + if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM: + raise ValueError('Disallowed data type {}'.format(kwargs['dtype'])) + + # This layer's weight Variable/PartitionedVariable. + w_tensor = getter(name, *args, **kwargs) + + if len(w_tensor.get_shape()) < 2: + raise ValueError( + 'Spectral norm can only be applied to multi-dimensional tensors') + + return spectral_normalize( + w_tensor, + power_iteration_rounds=power_iteration_rounds, + name=(name + '/spectral_normalize')) + + return _internal_getter + + +@contextlib.contextmanager +def keras_spectral_normalization(name_filter=_default_name_filter, + power_iteration_rounds=1): + """A context manager that enables Spectral Normalization for Keras. + + Keras doesn't respect the `custom_getter` in the VariableScope, so this is a + bit of a hack to make things work. + + Usage: + with keras_spectral_normalization(): + net = discriminator_fn(net) + + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yeilds a better approximation of the + true spectral norm. + + Yields: + A context manager that wraps the standard Keras variable creation method + with the `spectral_normalization_custom_getter`. + """ + original_make_variable = keras_base_layer_utils.make_variable + sn_getter = spectral_normalization_custom_getter( + name_filter=name_filter, power_iteration_rounds=power_iteration_rounds) + + def make_variable_wrapper(name, *args, **kwargs): + return sn_getter(original_make_variable, name, *args, **kwargs) + + keras_base_layer_utils.make_variable = make_variable_wrapper + + yield + + keras_base_layer_utils.make_variable = original_make_variable diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea21f70ec01950cfef5e4fa851c78b219d6062f --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py @@ -0,0 +1,354 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for features.spectral_normalization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import slim +from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl as spectral_normalization +from tensorflow.contrib.layers.python.layers import layers as contrib_layers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.layers import convolutional as keras_convolutional +from tensorflow.python.keras.layers import core as keras_core +from tensorflow.python.layers import convolutional as layers_convolutional +from tensorflow.python.layers import core as layers_core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class SpectralNormalizationTest(test.TestCase): + + def testComputeSpectralNorm(self): + weights = variable_scope.get_variable( + 'w', dtype=dtypes.float32, shape=[2, 3, 50, 100]) + weights = math_ops.multiply(weights, 10.0) + s = linalg_ops.svd( + array_ops.reshape(weights, [-1, weights.shape[-1]]), compute_uv=False) + true_sn = s[..., 0] + estimated_sn = spectral_normalization.compute_spectral_norm(weights) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + np_true_sn = sess.run(true_sn) + for i in range(50): + est = sess.run(estimated_sn) + if i < 1: + np_est_1 = est + if i < 4: + np_est_5 = est + if i < 9: + np_est_10 = est + np_est_50 = est + + # Check that the estimate improves with more iterations. + self.assertAlmostEqual(np_true_sn, np_est_50, 0) + self.assertGreater( + abs(np_true_sn - np_est_10), abs(np_true_sn - np_est_50)) + self.assertGreater( + abs(np_true_sn - np_est_5), abs(np_true_sn - np_est_10)) + self.assertGreater(abs(np_true_sn - np_est_1), abs(np_true_sn - np_est_5)) + + def testSpectralNormalize(self): + weights = variable_scope.get_variable( + 'w', dtype=dtypes.float32, shape=[2, 3, 50, 100]) + weights = math_ops.multiply(weights, 10.0) + normalized_weights = spectral_normalization.spectral_normalize( + weights, power_iteration_rounds=1) + + unnormalized_sigma = linalg_ops.svd( + array_ops.reshape(weights, [-1, weights.shape[-1]]), + compute_uv=False)[..., 0] + normalized_sigma = linalg_ops.svd( + array_ops.reshape(normalized_weights, [-1, weights.shape[-1]]), + compute_uv=False)[..., 0] + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + s0 = sess.run(unnormalized_sigma) + + for i in range(50): + sigma = sess.run(normalized_sigma) + if i < 1: + s1 = sigma + if i < 5: + s5 = sigma + if i < 10: + s10 = sigma + s50 = sigma + + self.assertAlmostEqual(1., s50, 0) + self.assertGreater(abs(s10 - 1.), abs(s50 - 1.)) + self.assertGreater(abs(s5 - 1.), abs(s10 - 1.)) + self.assertGreater(abs(s1 - 1.), abs(s5 - 1.)) + self.assertGreater(abs(s0 - 1.), abs(s1 - 1.)) + + def _testLayerHelper(self, build_layer_fn, w_shape, b_shape, is_keras=False): + x = array_ops.placeholder(dtypes.float32, shape=[2, 10, 10, 3]) + + w_initial = np.random.randn(*w_shape) * 10 + w_initializer = init_ops.constant_initializer(w_initial) + b_initial = np.random.randn(*b_shape) + b_initializer = init_ops.constant_initializer(b_initial) + + if is_keras: + context_manager = spectral_normalization.keras_spectral_normalization() + else: + getter = spectral_normalization.spectral_normalization_custom_getter() + context_manager = variable_scope.variable_scope('', custom_getter=getter) + + with context_manager: + (net, + expected_normalized_vars, expected_not_normalized_vars) = build_layer_fn( + x, w_initializer, b_initializer) + + x_data = np.random.rand(*x.shape) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + + # Before running a forward pass we still expect the variables values to + # differ from the initial value because of the normalizer. + w_befores = [] + for name, var in expected_normalized_vars.items(): + w_before = sess.run(var) + w_befores.append(w_before) + self.assertFalse( + np.allclose(w_initial, w_before), + msg=('%s appears not to be normalized. Before: %s After: %s' % + (name, w_initial, w_before))) + + # Not true for the unnormalized variables. + for name, var in expected_not_normalized_vars.items(): + b_before = sess.run(var) + self.assertTrue( + np.allclose(b_initial, b_before), + msg=('%s appears to be unexpectedly normalized. ' + 'Before: %s After: %s' % (name, b_initial, b_before))) + + # Run a bunch of forward passes. + for _ in range(1000): + _ = sess.run(net, feed_dict={x: x_data}) + + # We expect this to have improved the estimate of the spectral norm, + # which should have changed the variable values and brought them close + # to the true Spectral Normalized values. + _, s, _ = np.linalg.svd(w_initial.reshape([-1, 3])) + exactly_normalized = w_initial / s[0] + for w_before, (name, var) in zip(w_befores, + expected_normalized_vars.items()): + w_after = sess.run(var) + self.assertFalse( + np.allclose(w_before, w_after, rtol=1e-8, atol=1e-8), + msg=('%s did not improve over many iterations. ' + 'Before: %s After: %s' % (name, w_before, w_after))) + self.assertAllClose( + exactly_normalized, + w_after, + rtol=1e-4, + atol=1e-4, + msg=('Estimate of spectral norm for %s was innacurate. ' + 'Normalized matrices do not match.' + 'Estimate: %s Actual: %s' % (name, w_after, + exactly_normalized))) + + def testConv2D_Layers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + layer = layers_convolutional.Conv2D( + filters=3, + kernel_size=3, + padding='same', + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'tf.layers.Conv2d.kernel': layer.kernel} + expected_not_normalized_vars = {'tf.layers.Conv2d.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) + + def testConv2D_ContribLayers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['CONTRIB_LAYERS_CONV2D_WEIGHTS'], + 'biases': ['CONTRIB_LAYERS_CONV2D_BIASES'] + } + net = contrib_layers.conv2d( + x, + 3, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = { + 'contrib.layers.conv2d.weights': weight_vars[0] + } + expected_not_normalized_vars = { + 'contrib.layers.conv2d.bias': bias_vars[0] + } + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) + + def testConv2D_Slim(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['SLIM_CONV2D_WEIGHTS'], + 'biases': ['SLIM_CONV2D_BIASES'] + } + net = slim.conv2d( + x, + 3, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('SLIM_CONV2D_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('SLIM_CONV2D_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = {'slim.conv2d.weights': weight_vars[0]} + expected_not_normalized_vars = {'slim.conv2d.bias': bias_vars[0]} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) + + def testConv2D_Keras(self): + + def build_layer_fn(x, w_initializer, b_initializer): + layer = keras_convolutional.Conv2D( + filters=3, + kernel_size=3, + padding='same', + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'keras.layers.Conv2d.kernel': layer.kernel} + expected_not_normalized_vars = {'keras.layers.Conv2d.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,), is_keras=True) + + def testFC_Layers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + x = layers_core.Flatten()(x) + layer = layers_core.Dense( + units=3, + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'tf.layers.Dense.kernel': layer.kernel} + expected_not_normalized_vars = {'tf.layers.Dense.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,)) + + def testFC_ContribLayers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['CONTRIB_LAYERS_FC_WEIGHTS'], + 'biases': ['CONTRIB_LAYERS_FC_BIASES'] + } + x = contrib_layers.flatten(x) + net = contrib_layers.fully_connected( + x, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('CONTRIB_LAYERS_FC_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('CONTRIB_LAYERS_FC_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = { + 'contrib.layers.fully_connected.weights': weight_vars[0] + } + expected_not_normalized_vars = { + 'contrib.layers.fully_connected.bias': bias_vars[0] + } + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,)) + + def testFC_Slim(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['SLIM_FC_WEIGHTS'], + 'biases': ['SLIM_FC_BIASES'] + } + x = slim.flatten(x) + net = slim.fully_connected( + x, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('SLIM_FC_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('SLIM_FC_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = { + 'slim.fully_connected.weights': weight_vars[0] + } + expected_not_normalized_vars = {'slim.fully_connected.bias': bias_vars[0]} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,)) + + def testFC_Keras(self): + + def build_layer_fn(x, w_initializer, b_initializer): + x = keras_core.Flatten()(x) + layer = keras_core.Dense( + units=3, + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'keras.layers.Dense.kernel': layer.kernel} + expected_not_normalized_vars = {'keras.layers.Dense.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,), is_keras=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index a0a86c6337eefa756a209635faa70db686a36247..1f1ae2df4d6def618e86aced3296ac89c836eab7 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -28,7 +28,7 @@ wasserstein_gradient_penalty All losses must be able to accept 1D or 2D Tensors, so as to be compatible with patchGAN style losses (https://arxiv.org/abs/1611.07004). -To make these losses usable in the TFGAN framework, please create a tuple +To make these losses usable in the TF-GAN framework, please create a tuple version of the losses with `losses_utils.py`. """ @@ -38,6 +38,7 @@ from __future__ import print_function from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -69,6 +70,10 @@ __all__ = [ ] +def _to_float(tensor): + return math_ops.cast(tensor, dtypes.float32) + + # Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). def wasserstein_generator_loss( discriminator_gen_outputs, @@ -98,7 +103,7 @@ def wasserstein_generator_loss( """ with ops.name_scope(scope, 'generator_wasserstein_loss', ( discriminator_gen_outputs, weights)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = - discriminator_gen_outputs loss = losses.compute_weighted_loss( @@ -144,8 +149,8 @@ def wasserstein_discriminator_loss( with ops.name_scope(scope, 'discriminator_wasserstein_loss', ( discriminator_real_outputs, discriminator_gen_outputs, real_weights, generated_weights)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) @@ -320,7 +325,7 @@ def wasserstein_gradient_penalty( generated_data: Output of the generator. generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator. - discriminator_fn: A discriminator function that conforms to TFGAN API. + discriminator_fn: A discriminator function that conforms to TF-GAN API. discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. @@ -647,7 +652,7 @@ def least_squares_generator_loss( """ with ops.name_scope(scope, 'lsq_generator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = math_ops.squared_difference( discriminator_gen_outputs, real_label) / 2.0 loss = losses.compute_weighted_loss( @@ -702,8 +707,8 @@ def least_squares_discriminator_loss( """ with ops.name_scope(scope, 'lsq_discriminator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index 221c70c38bd432a6be7f6cda9c6700aa2255821f..76e57df7f646547037b3461ac44f7ee5b971406c 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN utilities for loss functions that accept GANModel namedtuples. +"""TF-GAN utilities for loss functions that accept GANModel namedtuples. The losses and penalties in this file all correspond to losses in `losses_impl.py`. Losses in that file take individual arguments, whereas in this diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 969b68449d9c82f9f9144a8657cd8932b38fd0f7..73dfee4fdeec87cf0bac5eb675fd02a64a9ad7f5 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN. +"""Named tuples for TF-GAN. -TFGAN training occurs in four steps, and each step communicates with the next -step via one of these named tuples. At each step, you can either use a TFGAN +TF-GAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. At each step, you can either use a TF-GAN helper function in `train.py`, or you can manually construct a tuple. """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 4c7bee41b33ce1fee46d374ca5fd1c0b603762f9..f36a5d346e0f27fbbc480e876380db51ed559c09 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The TFGAN project provides a lightweight GAN training/testing framework. +"""The TF-GAN project provides a lightweight GAN training/testing framework. This file contains the core helper functions to create and train a GAN model. See the README or examples in `tensorflow_models` for details on how to use. -TFGAN training occurs in four steps: +TF-GAN training occurs in four steps: 1) Create a model 2) Add a loss 3) Create train ops @@ -645,9 +645,10 @@ def gan_loss( type(model)) # Optionally create pooled model. - pooled_model = ( - _tensor_pool_adjusted_model(model, tensor_pool_fn) - if tensor_pool_fn else model) + if tensor_pool_fn: + pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn) + else: + pooled_model = model # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) @@ -665,10 +666,11 @@ def gan_loss( if _use_aux_loss(mutual_information_penalty_weight): gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_info_loss = ( - gen_info_loss - if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( - pooled_model, add_summaries=add_summaries)) + if tensor_pool_fn is None: + dis_info_loss = gen_info_loss + else: + dis_info_loss = tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries) gen_loss += mutual_information_penalty_weight * gen_info_loss dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): @@ -929,7 +931,7 @@ def gan_train_ops( **kwargs): """Returns GAN train ops. - The highest-level call in TFGAN. It is composed of functions that can also + The highest-level call in TF-GAN. It is composed of functions that can also be called, should a user require more control over some part of the GAN training process. diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 4ef0a66a52429233c6e6f70667a451466493629c..294a7d69a704b3c06ab9e30489af116929ab6c2a 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -34,7 +34,7 @@ def sparse_multiclass_hinge_loss( scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS): - """Adds Ops for computing the multiclass hinge loss. + r"""Adds Ops for computing the multiclass hinge loss. The implementation is based on the following paper: On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 42b91d031375b8edb7e4f364ac91ffb74ef1f54b..19daffea6c7e4486499388314d0aaaa611e94218 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,3 +1,3 @@ # K-FAC: Kronecker-Factored Approximate Curvature -## KFAC moved to third_party/tensorflow_kfac. +## KFAC moved to https://github.com/tensorflow/kfac. diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index 11033a2e9cb646c2e7cd2f45de1f751d88c6921a..76b03ff514821d3459f84c5f46a64d1134e0d4de 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -186,7 +186,7 @@ def group_norm(inputs, Args: inputs: A Tensor with at least 2 dimensions one which is channels. All - shape dimensions must be fully defined. + shape dimensions except for batch must be fully defined. groups: Integer. Divide the channels into this number of groups over which normalization statistics are computed. This number must be commensurate with the number of channels in `inputs`. @@ -249,13 +249,21 @@ def group_norm(inputs, """ # TODO(shlens): Support partially defined shapes for the inputs. inputs = ops.convert_to_tensor(inputs) - original_shape = inputs.shape if inputs.shape.ndims is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) if channels_axis > (inputs.shape.ndims - 1): raise ValueError('Axis is out of bounds.') + # Use dynamic shape for not fully defined dimensions in the inputs. + dyanmic_shape = array_ops.shape(inputs) + input_shape_list = [] + for i, dim in enumerate(inputs.shape): + if dim.value is None: + input_shape_list.append(dyanmic_shape[i]) + else: + input_shape_list.append(dim) + # Standardize the channels_axis to be positive and identify # of channels. if channels_axis < 0: channels_axis = inputs.shape.ndims + channels_axis @@ -289,8 +297,8 @@ def group_norm(inputs, # Determine axes before channels. Some examples of common image formats: # 'NCHW': before = [N], after = [HW] # 'NHWC': before = [NHW], after = [] - axes_before_channels = inputs.shape.as_list()[:channels_axis] - axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + axes_before_channels = input_shape_list[:channels_axis] + axes_after_channels = input_shape_list[channels_axis+1:] # Manually broadcast the parameters to conform to the number of groups. params_shape_broadcast = ([1] * len(axes_before_channels) + @@ -369,7 +377,7 @@ def group_norm(inputs, outputs = inputs * gain + offset # Collapse the groups into the channel dimension. - outputs = array_ops.reshape(outputs, original_shape) + outputs = array_ops.reshape(outputs, input_shape_list) if activation_fn is not None: outputs = activation_fn(outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index c8d3c91b10dbe3b959e91182f9924b78352d370d..9a85084b239837ade87d8c778393ef8e885f5bdd 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -221,6 +221,15 @@ class GroupNormTest(test.TestCase): normalization.group_norm(inputs, channels_axis=-1, reduction_axes=[-3, -2]) + def testParamsShapeNotFullyDefinedBatchAxis(self): + height, width, groups = 3, 3, 4 + inputs = array_ops.placeholder(dtypes.float32, + shape=(None, height, width, 2*groups)) + output = normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2], groups=groups) + self.assertListEqual([None, height, width, 2 * groups], + output.shape.as_list()) + def testCreateOp(self): height, width, groups = 3, 3, 4 images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py index 5e90d1fa20535de3b5e25bc7ff8c3862cea5514c..318046733bf75a6d661d26f478118c8e944afe15 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py @@ -174,7 +174,7 @@ class GeneratorIoTest(test.TestCase): return np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must be generator'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() @@ -185,7 +185,7 @@ class GeneratorIoTest(test.TestCase): yield np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must yield dict'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index ecac06354d2ce796f2a6021cdf2370d7c30ccab7..a7be92a35e0d62a61f7923ac61bb2c1267d039c6 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,7 +52,6 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", - "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py index 3fb649ea82e79b3bc78a2da6d5c3e9a071adec6d..c5c9fc74deaf0171a33d0eb1b5c6f60b3aa5e533 100644 --- a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -41,7 +41,7 @@ class AdamGSOptimizer(optimizer.Optimizer): def __init__(self, global_step=0, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="Adam"): - """Construct a new Adam optimizer. + r"""Construct a new Adam optimizer. Branched from tf.train.AdamOptimizer. The only difference is to pass global step for computing beta1 and beta2 accumulators, instead of having @@ -83,23 +83,20 @@ class AdamGSOptimizer(optimizer.Optimizer): Args: global_step: tensorflow variable indicating the step. learning_rate: A Tensor or a floating point value. The learning rate. - beta1: A float value or a constant float tensor. - The exponential decay rate for the 1st moment estimates. - beta2: A float value or a constant float tensor. - The exponential decay rate for the 2nd moment estimates. + beta1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. This epsilon is "epsilon hat" in the Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm 1 of the paper. use_locking: If True use locks for update operations. name: Optional name for the operations created when applying gradients. - Defaults to "Adam". - - @compatibility(eager) - When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and - `epsilon` can each be a callable that takes no arguments and returns the - actual value to use. This can be useful for changing these values across - different invocations of optimizer functions. - @end_compatibility + Defaults to "Adam". @compatibility(eager) When eager execution is + enabled, `learning_rate`, `beta1`, `beta2`, and `epsilon` can each be a + callable that takes no arguments and returns the actual value to use. + This can be useful for changing these values across different + invocations of optimizer functions. @end_compatibility """ super(AdamGSOptimizer, self).__init__(use_locking, name) self._lr = learning_rate diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index 248ffb1f7eb5dc27112ddf9b8670344904065ed0..1b7800f324b908e3c88fe90d31a2a08cbbd5ccf2 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -36,7 +36,7 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="Adam"): - """Construct a new Adam optimizer. + r"""Construct a new Adam optimizer. Initialization: diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 7fb23abc38d9dc101204ed83808aebe5a8ef1e78..1323ed014c9e51e273491694fa44a8e36cc723d0 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -843,8 +843,7 @@ class OptimizerV2(optimizer_v1.Optimizer): scale_loss_by_num_replicas = ( distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: - num_replicas = \ - distribute_ctx.get_distribution_strategy().num_replicas_in_sync + num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync if num_replicas > 1: loss_value *= 1. / num_replicas return loss_value diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 44b232e0f2b26f16f0300e11cf2764e1157a0050..39b688596875ab1b208d97a5d6f9a5ee811674cb 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -227,7 +227,10 @@ tf_custom_op_library( "kernels/lstm_ops_gpu.cu.cc", "kernels/lstm_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -249,7 +252,10 @@ tf_custom_op_library( "kernels/gru_ops_gpu.cu.cc", "kernels/gru_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -346,6 +352,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", ], diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h index d37210d4b81203287fb633adc309688a35d093bb..12f3182a6a8878aa27ee143fa6405903e3fc4ef3 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.h +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h @@ -21,6 +21,10 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + namespace tensorflow { class OpKernelContext; namespace functor { diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 18b56cd21942e28cb0dc3210df0bb04d55c1e16f..89176180ae0dd963bccc34aa2d0fc52be839dd3f 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -33,7 +33,6 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":beam_search_ops", - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/util:util_py", @@ -59,7 +58,6 @@ tf_custom_op_py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index b7f9f3fb090356a1c8d2bfb5044712ff93e267ce..abcf71c61b6e6df9462bf06323b8b11d5cc0d9a8 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -34,8 +34,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top @@ -517,7 +515,7 @@ class BasicDecoderTest(test.TestCase): vocabulary_size) # The sample function samples categorically from the logits. - sample_fn = lambda x: categorical.Categorical(logits=x).sample() + sample_fn = lambda x: helper_py.categorical_sample(logits=x) # The next inputs are a one-hot encoding of the sampled labels. next_inputs_fn = ( lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) @@ -599,7 +597,7 @@ class BasicDecoderTest(test.TestCase): # The sample function samples independent bernoullis from the logits. sample_fn = ( - lambda x: bernoulli.Bernoulli(logits=x, dtype=dtypes.bool).sample()) + lambda x: helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool)) # The next inputs are a one-hot encoding of the sampled labels. next_inputs_fn = math_ops.to_float end_fn = lambda sample_ids: sample_ids[:, end_token] diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py index 5aa32b532ffcf5772f6ace26662f5e5471cf6923..41b2a53ca5b178be9b04446c81d832575e5ed75b 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py @@ -14,80 +14,254 @@ # ============================================================================== """Tests for contrib.seq2seq.python.seq2seq.loss_ops.""" -# pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: enable=unused-import import numpy as np from tensorflow.contrib.seq2seq.python.ops import loss from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class LossTest(test.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 3 + self.number_of_classes = 5 + logits = [ + constant_op.constant(i + 0.5, shape=[self.batch_size, + self.number_of_classes]) + for i in range(self.sequence_length) + ] + self.logits = array_ops.stack(logits, axis=1) + targets = [ + constant_op.constant(i, dtypes.int32, shape=[self.batch_size]) + for i in range(self.sequence_length) + ] + self.targets = array_ops.stack(targets, axis=1) + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + self.weights = array_ops.stack(weights, axis=1) + # expected_loss = sparse_softmax_cross_entropy_with_logits(targets, logits) + # where targets = [0, 1, 2], and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5] + self.expected_loss = 1.60944 + def testSequenceLoss(self): - with self.session(use_gpu=True) as sess: - with variable_scope.variable_scope( - 'root', initializer=init_ops.constant_initializer(0.5)): - batch_size = 2 - sequence_length = 3 - number_of_classes = 5 - logits = [ - constant_op.constant( - i + 0.5, shape=[batch_size, number_of_classes]) - for i in range(sequence_length) - ] - logits = array_ops.stack(logits, axis=1) - targets = [ - constant_op.constant( - i, dtypes.int32, shape=[batch_size]) - for i in range(sequence_length) - ] - targets = array_ops.stack(targets, axis=1) - weights = [ - constant_op.constant( - 1.0, shape=[batch_size]) for i in range(sequence_length) - ] - weights = array_ops.stack(weights, axis=1) - - average_loss_per_example = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=True) - res = sess.run(average_loss_per_example) - self.assertAllClose(1.60944, res) - - average_loss_per_sequence = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=True) - res = sess.run(average_loss_per_sequence) - compare_per_sequence = np.ones((sequence_length)) * 1.60944 - self.assertAllClose(compare_per_sequence, res) - - average_loss_per_batch = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=False) - res = sess.run(average_loss_per_batch) - compare_per_batch = np.ones((batch_size)) * 1.60944 - self.assertAllClose(compare_per_batch, res) - - total_loss = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=False) - res = sess.run(total_loss) - compare_total = np.ones((batch_size, sequence_length)) * 1.60944 - self.assertAllClose(compare_total, res) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSequenceLossClass(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSumReduction(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testWeightedSumReduction(self): + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + # Make the last element in the sequence to have zero weights. + weights[-1] = constant_op.constant(0.0, shape=[self.batch_size]) + self.weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + # The last element in every sequence are zeros, which will be filtered. + compare_per_sequence[-1] = 0. + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + # The last element in every sequence are zeros, which will be filtered. + compare_total[:, -1] = 0 + self.assertAllClose(compare_total, res) + + def testZeroWeights(self): + weights = [ + constant_op.constant(0.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(0.0, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.zeros((self.sequence_length)) + self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.zeros((self.batch_size)) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.zeros((self.batch_size, self.sequence_length)) + self.assertAllClose(compare_total, res) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index 3245cc5e72154289ea3ba000b9a30586a7ad03a9..033c2eb0801d5a51ee937f5e960faa91a6f1ae54 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -32,9 +32,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.util import nest __all__ = [ @@ -51,6 +50,68 @@ __all__ = [ _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access +# The following sample functions (_call_sampler, bernoulli_sample, +# categorical_sample) mimic TensorFlow Probability distribution semantics. + + +def _call_sampler(sample_n_fn, sample_shape, name=None): + """Reshapes vector of samples.""" + with ops.name_scope(name, "call_sampler", values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + # Ensure sample_shape is a vector (vs just a scalar). + pad = math_ops.cast(math_ops.equal(array_ops.rank(sample_shape), 0), + dtypes.int32) + sample_shape = array_ops.reshape( + sample_shape, + array_ops.pad(array_ops.shape(sample_shape), + paddings=[[pad, 0]], + constant_values=1)) + samples = sample_n_fn(math_ops.reduce_prod(sample_shape)) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + return array_ops.reshape(samples, final_shape) + + +def bernoulli_sample(probs=None, logits=None, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from Bernoulli distribution.""" + if probs is None: + probs = math_ops.sigmoid(logits, name="probs") + else: + probs = ops.convert_to_tensor(probs, name="probs") + batch_shape_tensor = array_ops.shape(probs) + def _sample_n(n): + """Sample vector of Bernoullis.""" + new_shape = array_ops.concat([[n], batch_shape_tensor], 0) + uniform = random_ops.random_uniform( + new_shape, seed=seed, dtype=probs.dtype) + return math_ops.cast(math_ops.less(uniform, probs), dtype) + return _call_sampler(_sample_n, sample_shape) + + +def categorical_sample(logits, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from categorical distribution.""" + logits = ops.convert_to_tensor(logits, name="logits") + event_size = array_ops.shape(logits)[-1] + batch_shape_tensor = array_ops.shape(logits)[:-1] + def _sample_n(n): + """Sample vector of categoricals.""" + if logits.shape.ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, event_size]) + sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) + draws = array_ops.reshape( + array_ops.transpose(draws), + array_ops.concat([[n], batch_shape_tensor], 0)) + return math_ops.cast(draws, dtype) + return _call_sampler(_sample_n, sample_shape) + + def _unstack_ta(inp): return tensor_array_ops.TensorArray( dtype=inp.dtype, size=array_ops.shape(inp)[0], @@ -307,14 +368,14 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper): with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", [time, outputs, state]): # Return -1s where we did not sample, and sample_ids elsewhere - select_sampler = bernoulli.Bernoulli( - probs=self._sampling_probability, dtype=dtypes.bool) - select_sample = select_sampler.sample( - sample_shape=self.batch_size, seed=self._scheduling_seed) - sample_id_sampler = categorical.Categorical(logits=outputs) + select_sample = bernoulli_sample( + probs=self._sampling_probability, + dtype=dtypes.bool, + sample_shape=self.batch_size, + seed=self._scheduling_seed) return array_ops.where( select_sample, - sample_id_sampler.sample(seed=self._seed), + categorical_sample(logits=outputs, seed=self._seed), gen_array_ops.fill([self.batch_size], -1)) def next_inputs(self, time, outputs, state, sample_ids, name=None): @@ -425,8 +486,10 @@ class ScheduledOutputTrainingHelper(TrainingHelper): def sample(self, time, outputs, state, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", [time, outputs, state]): - sampler = bernoulli.Bernoulli(probs=self._sampling_probability) - return sampler.sample(sample_shape=self.batch_size, seed=self._seed) + return bernoulli_sample( + probs=self._sampling_probability, + sample_shape=self.batch_size, + seed=self._seed) def next_inputs(self, time, outputs, state, sample_ids, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", @@ -610,8 +673,7 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): else: logits = outputs / self._softmax_temperature - sample_id_sampler = categorical.Categorical(logits=logits) - sample_ids = sample_id_sampler.sample(seed=self._seed) + sample_ids = categorical_sample(logits=logits, seed=self._seed) return sample_ids diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py index 39a6d2f58b140706a94d83273d3327edd1891368..0fbfd6187030f14ac105a18b3e09b7a42d4de32a 100644 --- a/tensorflow/contrib/seq2seq/python/ops/loss.py +++ b/tensorflow/contrib/seq2seq/python/ops/loss.py @@ -20,11 +20,12 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.keras.losses import Loss from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -__all__ = ["sequence_loss"] +__all__ = ["sequence_loss", "SequenceLoss"] def sequence_loss(logits, @@ -32,16 +33,26 @@ def sequence_loss(logits, weights, average_across_timesteps=True, average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False, softmax_loss_function=None, name=None): """Weighted cross-entropy loss for a sequence of logits. - Depending on the values of `average_across_timesteps` and - `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these - arguments reduce the cross-entropy at each target, which has shape - `[batch_size, sequence_length]`, over their respective dimensions. For - example, if `average_across_timesteps` is `True` and `average_across_batch` - is `False`, then the return Tensor will have shape `[batch_size]`. + Depending on the values of `average_across_timesteps` / `sum_over_timesteps` + and `average_across_batch` / `sum_over_batch`, the return Tensor will have + rank 0, 1, or 2 as these arguments reduce the cross-entropy at each target, + which has shape `[batch_size, sequence_length]`, over their respective + dimensions. For example, if `average_across_timesteps` is `True` and + `average_across_batch` is `False`, then the return Tensor will have shape + `[batch_size]`. + + Note that `average_across_timesteps` and `sum_over_timesteps` cannot be True + at same time. Same for `average_across_batch` and `sum_over_batch`. + + The recommended loss reduction in tf 2.0 has been changed to sum_over, instead + of weighted average. User are recommend to use `sum_over_timesteps` and + `sum_over_batch` for reduction. Args: logits: A Tensor of shape @@ -58,6 +69,12 @@ def sequence_loss(logits, dimension and divide the cost by the total label weight across timesteps. average_across_batch: If set, sum the cost across the batch dimension and divide the returned cost by the batch size. + sum_over_timesteps: If set, sum the cost across the sequence dimension and + divide the size of the sequence. Note that any element with 0 weights will + be excluded from size calculation. + sum_over_batch: if set, sum the cost across the batch dimension and divide + the total cost by the batch size. Not that any element with 0 weights will + be excluded from size calculation. softmax_loss_function: Function (labels, logits) -> loss-batch to be used instead of the standard softmax (the default if this is None). **Note that to avoid confusion, it is required for the function to accept @@ -78,11 +95,15 @@ def sequence_loss(logits, raise ValueError("Logits must be a " "[batch_size x sequence_length x logits] tensor") if len(targets.get_shape()) != 2: - raise ValueError("Targets must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Targets must be a [batch_size x sequence_length] tensor") if len(weights.get_shape()) != 2: - raise ValueError("Weights must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Weights must be a [batch_size x sequence_length] tensor") + if average_across_timesteps and sum_over_timesteps: + raise ValueError("average_across_timesteps and sum_over_timesteps cannot " + "be set to True at same time.") + if average_across_batch and sum_over_batch: + raise ValueError("average_across_batch and sum_over_batch cannot be set " + "to True at same time.") with ops.name_scope(name, "sequence_loss", [logits, targets, weights]): num_classes = array_ops.shape(logits)[2] logits_flat = array_ops.reshape(logits, [-1, num_classes]) @@ -96,20 +117,56 @@ def sequence_loss(logits, if average_across_timesteps and average_across_batch: crossent = math_ops.reduce_sum(crossent) total_size = math_ops.reduce_sum(weights) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps and sum_over_batch: + crossent = math_ops.reduce_sum(crossent) + total_count = math_ops.cast(math_ops.count_nonzero(weights), + crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) else: - batch_size = array_ops.shape(logits)[0] - sequence_length = array_ops.shape(logits)[1] - crossent = array_ops.reshape(crossent, [batch_size, sequence_length]) - if average_across_timesteps and not average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[1]) - total_size = math_ops.reduce_sum(weights, axis=[1]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size - if not average_across_timesteps and average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[0]) - total_size = math_ops.reduce_sum(weights, axis=[0]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = array_ops.reshape(crossent, array_ops.shape(logits)[0:2]) + if average_across_timesteps or average_across_batch: + reduce_axis = [0] if average_across_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_size = math_ops.reduce_sum(weights, axis=reduce_axis) + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps or sum_over_batch: + reduce_axis = [0] if sum_over_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_count = math_ops.cast( + math_ops.count_nonzero(weights, axis=reduce_axis), + dtype=crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) return crossent + + +class SequenceLoss(Loss): + """Weighted cross-entropy loss for a sequence of logits.""" + + def __init__(self, + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True, + softmax_loss_function=None, + name=None): + super(SequenceLoss, self).__init__(name=name) + self.average_across_timesteps = average_across_timesteps + self.average_across_batch = average_across_batch + self.sum_over_timesteps = sum_over_timesteps + self.sum_over_batch = sum_over_batch + self.softmax_loss_function = softmax_loss_function + + def __call__(self, y_true, y_pred, sample_weight=None): + """Override the parent __call__ to have a customized reduce behavior.""" + return sequence_loss(y_pred, y_true, sample_weight, + average_across_timesteps=self.average_across_timesteps, + average_across_batch=self.average_across_batch, + sum_over_timesteps=self.sum_over_timesteps, + sum_over_batch=self.sum_over_batch, + softmax_loss_function=self.softmax_loss_function, + name=self.name) + + def call(self, y_true, y_pred): + # Skip this method since the __call__ contains real implementation. + pass diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index d3edb43733761a906c6e5bf8b65f76e3e1ae56fc..3100a5a0e5da1103b61bd089cd433721686b9e72 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -32,7 +32,7 @@ class DecisionTreeResource : public ResourceBase { // Constructor. explicit DecisionTreeResource(const TensorForestParams& params); - string DebugString() override { + string DebugString() const override { return strings::StrCat("DecisionTree[size=", decision_tree_->decision_tree().nodes_size(), "]"); } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index eea0be27caf0a022ba7acaacd359c75a2df4eedb..44f2b3f473b9eced06bd800b9cf0a5a0825ec3eb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -40,7 +40,7 @@ class FertileStatsResource : public ResourceBase { model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); } - string DebugString() override { return "FertileStats"; } + string DebugString() const override { return "FertileStats"; } void ExtractFromProto(const FertileStats& stats); diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index da123e1623f9abc12dd7a44f3d7e4127740a62df..3d34b91a74f9e9f5b211fc60049f0749962d649a 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -491,6 +491,8 @@ cuda_py_tests( "test/binary_tensor_weight_broadcast_test.py", "test/concatenation_test.py", "test/const_broadcast_test.py", + "test/conv2d_test.py", + "test/identity_output_test.py", "test/manual_test.py", "test/memory_alignment_test.py", "test/multi_connection_neighbor_engine_test.py", @@ -549,7 +551,6 @@ cuda_py_test( ], tags = [ "no_cuda_on_cpu_tap", - "no_oss", # TODO(b/121194394): re-enable in OSS after OOM is fixed. "no_pip", "no_tap", # It is not able to download the mnist data. "no_windows", diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index bf2de94e04ae3f6817f7a679ce9fd88e750827dd..eef647473ab12a1425b8d7810bd1f39f8b818dc5 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -334,13 +334,12 @@ struct EdgePtrCompare { tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& segment_nodes, + const std::set& segment_nodes, const std::unordered_map& node_map, const std::vector& reverse_topo_order, EngineInfo* info) { - std::vector subgraph_node_ids; // Topologically sorted node ids. - std::set subgraph_node_names = segment_nodes; - std::set added_const_node_ids; // Used to prevent double insertion. + std::vector subgraph_nodes; // Topologically sorted nodes. + std::set added_const_nodes; // Used to prevent double insertion. std::set segment_devices; // Map from src_node_name+port to the unique port numbers of the TRT op, where @@ -352,9 +351,8 @@ tensorflow::Status GetEngineInfo( std::unordered_map input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { - const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = *it; + const Node* node = *it; + if (segment_nodes.count(node) == 0) continue; auto node_device = node->requested_device(); if (!node_device.empty()) { segment_devices.insert(node_device); @@ -366,8 +364,11 @@ tensorflow::Status GetEngineInfo( << " neither have requested device nor assigned device"; } } + subgraph_nodes.push_back(node); + const int node_id = node->id(); - subgraph_node_ids.push_back(node_id); + const string& node_name = node->name(); + // Create input connections. Sort edges first to make determnistic since // in_edges is a set of pointers. std::vector in_edges(node->in_edges().begin(), @@ -375,7 +376,7 @@ tensorflow::Status GetEngineInfo( std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); for (const auto edge : in_edges) { auto input_node = edge->src(); - if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + if (input_node->IsSource() || segment_nodes.count(input_node)) { continue; } if (edge->IsControlEdge()) { @@ -392,12 +393,11 @@ tensorflow::Status GetEngineInfo( // // Note that the segmenter already ensure that the constant data input // is valid and suppported by the engine. - if (!added_const_node_ids.insert(input_node->id()).second) { + if (!added_const_nodes.insert(input_node).second) { // Already added before. continue; } VLOG(1) << "Adding const node " << input_node->name(); - QCHECK(subgraph_node_names.insert(input_node->name()).second); // Since we already add (duplicate) the const input node to the segment // graphdef, it's now not a data dependency any more, but to make the // dependency correct we still add a control dependency. @@ -428,7 +428,7 @@ tensorflow::Status GetEngineInfo( std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); for (const auto edge : out_edges) { auto output_node = edge->dst(); - if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + if (output_node->IsSink() || segment_nodes.count(output_node)) { continue; } if (edge->IsControlEdge()) { @@ -456,12 +456,11 @@ tensorflow::Status GetEngineInfo( } // For each segment node in topological order. // Construct the const nodes first. - subgraph_node_ids.insert(subgraph_node_ids.begin(), - added_const_node_ids.begin(), - added_const_node_ids.end()); + subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(), + added_const_nodes.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_names, subgraph_node_ids, - &info->connections, &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_nodes, &info->connections, + &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -1033,27 +1032,31 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { cudaSetDevice(cuda_device_id); auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, &graph, alloc.get(), &engine_nodes); - // If status is ok, we successfully added the node to the graph and can - // remove segment ops. Otherwise graph is not modified. + string msg = StrCat("TensorRT node ", engine.engine_name, " added for segment ", i, " consisting of ", converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { LOG(INFO) << msg << " succeeded."; - for (auto node_name : converted_segments.at(i).first) { - graph.RemoveNode(node_map.at(node_name)); - } } else { // Graph is not modified. LOG(WARNING) << msg << " failed: " << status << ". Fallback to TF..."; } if (VLOG_IS_ON(1)) { msg = "Segment consists of nodes: "; - for (const string& node_name : converted_segments.at(i).first) { - StrAppend(&msg, node_name, ", "); + for (const Node* node : converted_segments.at(i).first) { + StrAppend(&msg, node->name(), ", "); } VLOG(1) << msg; } + + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. + if (status.ok()) { + for (const Node* node : converted_segments.at(i).first) { + graph.RemoveNode(const_cast(node)); + } + } } cudaSetDevice(old_cuda_device); graph.ToGraphDef(params.output_graph_def); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index adf8831b960172fc29b5d631e5b0533318d4764d..7b0c4b446d73e0d01eb3ab0c0431800b2661119b 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -879,6 +879,8 @@ Status Converter::ConvertNode(const NodeDef& node_def) { // We need to check the name before setting it. If the input is one of the // engine input, setting the name here will overwrite engine input // bindings which will cause runtime error. + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. if (output.is_tensor()) { const char* tensor_name = output.tensor()->getName(); if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) { @@ -939,6 +941,22 @@ Status Converter::RenameAndMarkOutputTensors( if (tensor == nullptr) { return errors::NotFound("Output tensor not found: ", output.first); } + // Check if this tensor has already been marked as an output. + // ConvertIdentity can cause the same tensor to be repeated in + // output_tensors, which can cause us to overwrite the name of the output + // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then + // we won't be able to locate OutputPH_0 during runtime. To fix this, + // duplicate the tensor using no-op shuffle. + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. + if (tensorflow::str_util::StartsWith(tensor->getName(), kOutputPHName)) { + // Using shuffle layer for identity by not setting reshape or transpose. + nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor); + TFTRT_RETURN_ERROR_IF_NULLPTR( + layer, StrCat("Output Copy for ", tensor->getName())); + MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); + tensor = layer->getOutput(0); + } tensor->setName(output.second.c_str()); VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " << output.second; @@ -1538,6 +1556,11 @@ enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument("Two inputs are expected for ", + node_def.op(), ", at ", + node_def.name()); + } if (inputs.at(0).is_weights()) { return tensorflow::errors::Unimplemented( node_def.op(), " is only implemented for tensors, not weights, at ", @@ -1549,39 +1572,61 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { node_def.name()); } TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - VLOG(2) << "weight shape: " << weights_rsck.DebugString(); if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::Internal( - "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + return tensorflow::errors::InvalidArgument( + "Conv2D expects kernel of dimension 4, at " + node_def.name()); } + TFAttrs attrs(node_def); + auto data_format = attrs.get("data_format"); + int c_index = (data_format == "NHWC") ? 3 : 1; + int h_index = (data_format == "NHWC") ? 1 : 2; + int w_index = (data_format == "NHWC") ? 2 : 3; + auto tf_dilations = attrs.get>("dilations"); + if (tf_dilations.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution dilations field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Dilation rate must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]); + + const auto tf_stride = attrs.get>("strides"); + if (tf_stride.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution strides field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_stride[0] != 1 || tf_stride[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Stride must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); if (params->validation_only) return tensorflow::Status::OK(); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - TFAttrs attrs(node_def); - int h_index = 2; - int w_index = 3; - auto data_format = attrs.get("data_format"); - if (data_format == "NHWC") { + // Transpose to NCHW (NCHW is required for IConvLayer). + const bool need_transpose = (data_format == "NHWC"); + if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(tensor), {0, 3, 1, 2}, &tensor)); - h_index = 1; - w_index = 2; - // TODO(jie): transpose it } - - // tensor after transpose (NCHW) + // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); - int num_groups = group; - if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution - VLOG(2) << "groups count: " << num_groups; + // For depthwise convolution, group will be 0 so set num_groups to size of + // input's channel dim. For a non-depthwise conv, num_groups will be 1. + const int num_groups = (group == 0) ? tensor_dim.d[0] : group; if (params->converter->precision_mode() == FP16MODE) { weights_rsck = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); } - TRT_ShapedWeights weights = params->weight_store->GetTempWeights(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); @@ -1590,35 +1635,22 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; - VLOG(2) << "RSCK: " << weights.DebugString(); - VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); - - // TODO(jie): stride. (NHWC/NCHW) - const auto tf_stride = attrs.get>("strides"); - VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; - VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2] - << tf_stride[3]; - const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + // Add padding. std::vector> padding; - // TODO(jie): padding. if (attrs.get("padding") == "SAME") { - // This is NCHW tensor with no batch dimension. - // 1 -> h - // 2 -> w + nvinfer1::DimsHW effective_kernel_size = kernel_size; + effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1); + effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1); padding = CreateSamePadding( - stride, kernel_size, + stride, effective_kernel_size, {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); } else { padding = {{0, 0}, {0, 0}}; } - if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // TODO(jie): handle asymmetric padding - VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second - << padding[1].first << padding[1].second; - VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions()); + // Handle asymmetric padding. auto pad_layer = params->converter->network()->addPadding( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), @@ -1628,24 +1660,23 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); - VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions()); } + // Add convolution. nvinfer1::IConvolutionLayer* layer = params->converter->network()->addConvolution( *const_cast(tensor), noutput, kernel_size, weights.GetTrtWeights(), biases.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); + layer->setDilation(dilation); const nvinfer1::ITensor* output_tensor = layer->getOutput(0); - VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions()); - VLOG(2) << "data_format: " << data_format; - if (data_format == "NHWC") { - // TODO(jie): transpose it back! + + // Restore transpose. + if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(output_tensor), {0, 2, 3, 1}, &output_tensor)); @@ -3749,8 +3780,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, // In topological order + const std::vector& subgraph_nodes, // In topological order std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope) { std::set marker_nodes; @@ -3824,11 +3854,10 @@ tensorflow::Status ConvertSegmentToGraphDef( std::unordered_map old_to_new_id_map; // Copy internal nodes to new graphdef - string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); - for (const auto node_id : subgraph_node_ids) { - const auto node = graph->FindNodeId(node_id); + string local_scope = subgraph_nodes.front()->name(); + for (const Node* node : subgraph_nodes) { local_scope = GetCommonNameScope(local_scope, node->name()); - old_to_new_id_map[node_id] = segment_def->node_size(); + old_to_new_id_map[node->id()] = segment_def->node_size(); auto snode = segment_def->add_node(); snode->CopyFrom(node->def()); VLOG(2) << "Copying " << snode->name() << " to subgraph"; @@ -3846,6 +3875,11 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + std::set subgraph_node_names; + for (const Node* node : subgraph_nodes) { + subgraph_node_names.insert(node->name()); + } + // Remove control inputs that are not inside the segment. for (int i = 0; i < segment_def->node_size(); ++i) { auto snode = segment_def->mutable_node(i); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 54e19b73957bccdae2b23bd3556de9ad00b864e5..8f2271ee3f5af0198695c913c6757ef0f6d60d61 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -128,8 +128,7 @@ struct EngineInfo { tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, + const std::vector& subgraph_nodes, std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index a2ddfbffa5b0d8c421bcfe054097a9e42b79fe8f..f143f56d2094b2418df0122e6a7e6f5d41136f31 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -2378,6 +2378,8 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { }; { + // Input is weights, should fail. + Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); AddTestWeights("begin", {4}, {0, 0, 0, 0}); @@ -2619,6 +2621,240 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { } } +TEST_F(OpConverterTest, ConvertConv2D) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Two inputs are expected for Conv2D, at my_conv2d"); + } + + // Get nodedef for Conv2D layer. + auto get_conv2d_nodedef = + [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", + string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + ops::Conv2D::Attrs attrs = + ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); + auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, + padding, attrs); + return conv2d.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Conv2D is only implemented for tensors, not weights, at my_conv2d"); + } + { + // Filter is tensor, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3, 3, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Kernel for Conv2D must be constant weights, at my_conv2d"); + } + { + // Filter is not 4D, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Conv2D expects kernel of dimension 4, at my_conv2d"); + } + { + // Dilations is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution dilations field must specify 4 dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel (NHWC), should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2}); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Strides is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution strides field must specify 4 dimensions, at my_conv2d"); + } + { + // Stride value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Stride must be 1 for batch and channel dimensions, at my_conv2d"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& input, + const std::vector& filter_dims, + const std::vector& filter, + const std::vector& strides, const string& padding, + const string& data_format, const std::vector& dilations, + const std::vector& expected_output_dims, + const std::vector& expected_output) + : input_dims(input_dims), + input(input), + filter_dims(filter_dims), + filter(filter), + strides(strides), + padding(padding), + data_format(data_format), + dilations(dilations), + expected_output_dims(expected_output_dims), + expected_output(expected_output) {} + + std::vector input_dims; + std::vector input; + std::vector filter_dims; + std::vector filter; + std::vector strides; + string padding; + string data_format; + std::vector dilations; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. + const int kConv2DOKCases = 6; + TestParams ok_params[kConv2DOKCases] = { + // Basic + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 1, 0, 1}}, + // SAME padding (Asymmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 1, -2, 0, 1, -4}}, + // SAME padding (Symmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 3, 1, 1}, + /*filter=*/{-1, 0, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, -1, 3, 1, -3}}, + // NHWC + TestParams{/*input_dims=*/{2, 3, 1}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{2, 2, 1}, + /*expected_output=*/{1, 1, 0, 1}}, + // Dilated + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 2}, + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 1}}, + // Strided + TestParams{/*input_dims=*/{1, 2, 4}, + /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 0, 1, 3}}, + }; + + for (int i = 0; i < kConv2DOKCases; i++) { + Reset(); + NodeDef node_def = + get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, + ok_params[i].data_format, ok_params[i].dilations); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", ok_params[i].filter_dims, + ok_params[i].filter); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + std::vector output_data(ok_params[i].expected_output.size()); + BuildAndRun({{"input", ok_params[i].input}}, "my_conv2d", + &output_data); + EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 203b2697babe32b45523109708cbf062dceee33b..92377edaecc8eb3bda3fcb12abb7aab48feedb2f 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -45,12 +45,19 @@ from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver -if _six.PY2: - _to_bytes = lambda s: s - _to_string = lambda s: s -else: - _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape") - _to_string = lambda s: s.decode("utf-8") + +def _to_bytes(s): + """Encode s if it is a sequence of chars.""" + if isinstance(s, _six.text_type): + return s.encode("utf-8", errors="surrogateescape") + return s + + +def _to_string(s): + """Decode s if it is a sequence of bytes.""" + if isinstance(s, _six.binary_type): + return s.decode("utf-8") + return s class TrtPrecisionMode(object): diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index aac9e5c7bd725fc10bcaa04536ebc7be071b4d4c..8d877b392faa6a6773aea573b23bd02300051f05 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -48,7 +48,7 @@ class TRTCalibrationResource : public tensorflow::ResourceBase { allocator_.reset(); } - string DebugString() override { + string DebugString() const override { std::stringstream oss; using std::dec; using std::endl; diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 084a96e0fa5c97edc58adf2590ed94e5ef0e4d85..ecaffa3023bc8f317d956181b44639bc80efda29 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -673,10 +673,11 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const std::set& segment_nodes = - itr.second; + const string& segment_root = itr.first; + // Return format does not require set comparator. + std::set segment_nodes(itr.second.begin(), itr.second.end()); if (VLOG_IS_ON(1)) { - string s = "parent=" + itr.first + ":"; + string s = "parent=" + segment_root + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } @@ -689,12 +690,10 @@ tensorflow::Status SegmentGraph( } // TODO(sami): Make segmenter placement aware once trtscopes are in place - std::set segment_node_names; - for (auto node : itr.second) segment_node_names.insert(node->name()); - const auto& dev_itr = device_maps.find(itr.first); + const auto& dev_itr = device_maps.find(segment_root); if (dev_itr == device_maps.end() || dev_itr->second.empty()) { VLOG(1) << "No device assigned to segment " << segments->size(); - segments->emplace_back(std::make_pair(segment_node_names, string())); + segments->emplace_back(std::make_pair(segment_nodes, string())); } else if (dev_itr->second.size() > 1) { string s("Segment "); StrAppend(&s, segments->size(), " has multiple devices attached: "); @@ -703,10 +702,10 @@ tensorflow::Status SegmentGraph( } LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin()); segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, *(dev_itr->second.begin()))); } else { segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, *(dev_itr->second.begin()))); } } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index b9693aad1b764515459db6833b05221ea5b3a2d1..6cc92cdb5df396a6bca26119f152487bc3685a6d 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -29,10 +29,10 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// Vector of segments, each entry contains a set of node names and a device name -// in the segment. -// TODO(aaroey): use node pointer instead of node name. -using SegmentNodesVector = std::vector, string>>; +// Vector of segments, each entry contains a set of node pointers and a device +// name in the segment. +using SegmentNodesVector = + std::vector, string>>; struct SegmentOptions { // Segment must contain at least this many nodes. diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 4805ef9c61a7784a1c08cf5eaf504691bc9dbedc..4ac02327ae68069278066b6e7e931bb9449c2603 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -75,7 +75,10 @@ class SegmentTest : public ::testing::Test { const std::vector>& expected_segments) { EXPECT_EQ(expected_segments.size(), segments.size()); for (int i = 0; i < segments.size(); ++i) { - const auto& segment_node_names = segments[i].first; + std::set segment_node_names; + for (const Node* node : segments[i].first) { + segment_node_names.insert(node->name()); + } const auto& expected = expected_segments[i]; for (const auto& name : expected) { EXPECT_TRUE(segment_node_names.count(name)) diff --git a/tensorflow/contrib/tensorrt/test/conv2d_test.py b/tensorflow/contrib/tensorrt/test/conv2d_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4cee3cce00c7f17af82f1135c272258cfb66833 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/conv2d_test.py @@ -0,0 +1,191 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.platform import test + + +def conv2d_layer(inputs, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + name=None): + dtype = inputs.dtype + c_axis = -1 if data_format == "channels_last" else 1 + nchan = inputs.shape[c_axis] + weights_shape = (kernel_size[0], kernel_size[1], nchan, filters) + weights = constant_op.constant(np.random.randn(*weights_shape), dtype=dtype) + padding = padding.upper() + if data_format == "channels_last": + strides = [1] + list(strides) + [1] + dilations = [1] + list(dilation_rate) + [1] + data_format = "NHWC" + else: + strides = [1, 1] + list(strides) + dilations = [1, 1] + list(dilation_rate) + data_format = "NCHW" + return gen_nn_ops.conv2d( + inputs, + weights, + strides=strides, + padding=padding, + dilations=dilations, + data_format=data_format) + + +def div_round_up(n, d): + return (n - 1) // d + 1 + + +def build_graph(input_dims, + dtype, + num_filters, + data_format, + kernel_sizes, + dilation_rates, + padding="same"): + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name="input") + with g.device("/GPU:0"): + results = [] + for kernel_size in kernel_sizes: + for dilation_rate in dilation_rates: + result = conv2d_layer(inp, num_filters, kernel_size, (1, 1), padding, + data_format, dilation_rate) + results.append(result) + output = sum(results) + output = array_ops.identity(output, name="output") + return g + + +class Conv2DNCHWTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of Conv2D (data_format=NCHW) in TF-TRT conversion.""" + np.random.seed(1234) + input_dims = [13, 3, 7, 11] + g = build_graph( + input_dims=input_dims, + dtype=dtypes.float32, + num_filters=5, + data_format="channels_first", + kernel_sizes=[(3, 3), (3, 2)], + dilation_rates=[(1, 1), (2, 3)]) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=["input"], + input_dims=[input_dims], + output_names=["output"], + expected_output_dims=[(13, 5, 7, 11)]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + +class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of Conv2D (data_format=NCHW) in TF-TRT conversion.""" + np.random.seed(1234) + input_dims = [13, 7, 11, 3] + g = build_graph( + input_dims=input_dims, + dtype=dtypes.float32, + num_filters=5, + data_format="channels_last", + kernel_sizes=[(3, 3), (3, 2)], + dilation_rates=[(1, 1), (2, 3)]) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=["input"], + input_dims=[input_dims], + output_names=["output"], + expected_output_dims=[(13, 7, 11, 5)]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + +class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of strided Conv2D (data_format=NCHW) in TF-TRT + + conversion. + """ + np.random.seed(1234) + dtype = dtypes.float32 + input_name = "input" + n, c, h, w = 13, 3, 7, 11 + num_filters = 5 + input_dims = [n, c, h, w] + output_name = "output" + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + output = inp + output = conv2d_layer( + output, + num_filters, (3, 2), + strides=(2, 2), + padding="same", + data_format="channels_first") + h = div_round_up(h, 2) + w = div_round_up(w, 2) + output = conv2d_layer( + output, + num_filters, (3, 3), + strides=(2, 2), + dilation_rate=(2, 3), + padding="same", + data_format="channels_first") + h = div_round_up(h, 2) + w = div_round_up(w, 2) + output = array_ops.identity(output, name=output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + output_names=[output_name], + expected_output_dims=[(n, num_filters, h, w)]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/identity_output_test.py b/tensorflow/contrib/tensorrt/test/identity_output_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c6dd52040f73f4adf8152e19717122798755bb --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/identity_output_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This test checks a situation where the same tensor is considered as an output + +multiple times because it has been duplicated by 2+ indentity ops. Previously, +the tensor would be renamed multiple times, overwriting the output binding name +which resulted in a runtime error when the binding would not be found. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class IdentityTest(trt_test.TfTrtIntegrationTestBase): + + def _ConstOp(self, shape): + return constant_op.constant(np.random.randn(*shape), dtype=dtypes.float32) + + def GetParams(self): + """Testing engine with the same tensor repeated as output via identity.""" + input_name = 'input' + input_dims = [100, 32] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + + b = self._ConstOp((32, 4)) + x1 = math_ops.matmul(x, b) + b = self._ConstOp((1, 4)) + x1 = x1 + b + + out1 = array_ops.identity(x1, name='output1') + out2 = array_ops.identity(x1, name='output2') + iden1 = array_ops.identity(x1) + out3 = array_ops.identity(iden1, name='output3') + + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + output_names=['output1', 'output2', 'output3'], + expected_output_dims=[(100, 4), (100, 4), (100, 4)]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ['TRTEngineOp_0'] + + def ShouldRunTest(self, run_params): + """Whether to run the test.""" + # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8 + # mode, which is a bug. Re-enable this when trt library is fixed. + return not trt_test.IsQuantizationMode(run_params.precision_mode) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py index e7d6ec4ad395d38a06f97020f2f363009f2286c7..79dde2a6b577acc7664aef37354e3a2d8e408cc2 100644 --- a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py @@ -144,7 +144,10 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): outputs=[OUTPUT_NODE_NAME], max_batch_size=max_batch_size, precision_mode='INT8', - max_workspace_size_bytes=4096 << 19, + # There is a 2GB GPU memory limit for each test, so we set + # max_workspace_size_bytes to 256MB to leave enough room for TF + # runtime to allocate GPU memory. + max_workspace_size_bytes=1 << 28, minimum_segment_size=2, use_calibration=False, ) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 495a9391a1e818a6078988161c9bf72f6143737f..671abba6a6862eb1c7a339d1897be0bb8ff90d30 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -239,8 +239,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" + conversion_params = self.GetConversionParams(run_params) if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: - conversion_params = self.GetConversionParams(run_params) rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( conversion_params.rewriter_config, conversion_params.max_batch_size, conversion_params.max_workspace_size_bytes, @@ -254,6 +254,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() + if conversion_params.rewriter_config is not None: + graph_options.rewrite_options.CopyFrom( + conversion_params.rewriter_config) config = config_pb2.ConfigProto( gpu_options=self._GetGPUOptions(), graph_options=graph_options) @@ -321,16 +324,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): return self._RunGraph( run_params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) - def _GetTrtGraphDef(self, run_params, gdef): + def _GetTrtGraphDef(self, run_params, graph_state, gdef): """Return trt converted graphdef.""" params = self._GetParamsCached() conversion_params = self.GetConversionParams(run_params) logging.info(conversion_params) - config_for_trt = config_pb2.ConfigProto(gpu_options=self._GetGPUOptions()) - if conversion_params.rewriter_config is not None: - config_for_trt.graph_options.rewrite_options.CopyFrom( - conversion_params.rewriter_config) + config_for_trt = self._GetConfigProto(run_params, graph_state) return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=params.input_names + params.output_names, @@ -506,7 +506,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): result = self._RunCalibration(run_params, input_gdef, input_data, calib_config) else: - calib_gdef = self._GetTrtGraphDef(run_params, input_gdef) + calib_gdef = self._GetTrtGraphDef(run_params, GraphState.CALIBRATE, + input_gdef) self._VerifyGraphDef(run_params, calib_gdef, GraphState.CALIBRATE) result = self._RunCalibration(run_params, calib_gdef, input_data, calib_config) @@ -527,7 +528,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): logging.info("Running final inference graph, config:\n%s", str(infer_config)) if not run_params.use_optimizer: - infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef) + infer_gdef = self._GetTrtGraphDef(run_params, GraphState.INFERENCE, + infer_gdef) self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE) result = self._RunGraph(run_params, infer_gdef, input_data, infer_config, diff --git a/tensorflow/contrib/timeseries/examples/predict_test.py b/tensorflow/contrib/timeseries/examples/predict_test.py index 678fd71cd8b94ee0be46e10a9a673de55bd44215..b353f85cb5df0cf961d1900b241e4fa1a84a24b4 100644 --- a/tensorflow/contrib/timeseries/examples/predict_test.py +++ b/tensorflow/contrib/timeseries/examples/predict_test.py @@ -43,10 +43,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual([700], mean.shape) self.assertAllEqual([700], upper_limit.shape) self.assertAllEqual([700], lower_limit.shape) - # Check that variance hasn't blown up too much. This is a relatively good - # indication that training was successful. - self.assertLess(upper_limit[-1] - lower_limit[-1], - 1.5 * (upper_limit[0] - lower_limit[0])) def test_ar(self): (times, observed, all_times, mean, @@ -55,7 +51,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual(all_times.shape, mean.shape) self.assertAllEqual(all_times.shape, upper_limit.shape) self.assertAllEqual(all_times.shape, lower_limit.shape) - self.assertLess((upper_limit - lower_limit).mean(), 4.) if __name__ == "__main__": diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 4b90b596b28efec83aa349782c4874d79b6817c7..a0c3204d41d8a715fe572223308a32eb39e0b963 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -361,9 +361,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_keys", + ":math_utils", ":model", ":model_utils", - "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index bcadf4094e1e79fff1685515f2bde0b88f717cac..a8d5e1a49dd4313f58f2f515bc3f292ecce5cbd4 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -18,9 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import model_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures @@ -462,8 +461,8 @@ class ARModel(model.TimeSeriesModel): if self.loss == ARModel.NORMAL_LIKELIHOOD_LOSS: covariance = prediction_ops["covariance"] sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=sigma) - loss_op = -math_ops.reduce_sum(normal.log_prob(prediction)) + loss_op = -math_ops.reduce_sum( + math_utils.normal_log_prob(targets, sigma, prediction)) else: assert self.loss == ARModel.SQUARED_LOSS, self.loss loss_op = math_ops.reduce_sum(math_ops.square(prediction - targets)) @@ -965,16 +964,11 @@ class AnomalyMixtureARModel(ARModel): anomaly_variance = prediction_ops["anomaly_params"] anomaly_sigma = math_ops.sqrt( gen_math_ops.maximum(anomaly_variance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=anomaly_sigma) - log_prob = normal.log_prob(prediction) + log_prob = math_utils.normal_log_prob(targets, anomaly_sigma, prediction) else: assert self._anomaly_distribution == AnomalyMixtureARModel.CAUCHY_ANOMALY anomaly_scale = prediction_ops["anomaly_params"] - cauchy = distributions.StudentT( - df=array_ops.ones([], dtype=anomaly_scale.dtype), - loc=targets, - scale=anomaly_scale) - log_prob = cauchy.log_prob(prediction) + log_prob = math_utils.cauchy_log_prob(targets, anomaly_scale, prediction) return log_prob def loss_op(self, targets, prediction_ops): @@ -983,8 +977,7 @@ class AnomalyMixtureARModel(ARModel): covariance = prediction_ops["covariance"] # Normal data log probability. sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal1 = distributions.Normal(loc=targets, scale=sigma) - log_prob1 = normal1.log_prob(prediction) + log_prob1 = math_utils.normal_log_prob(targets, sigma, prediction) log_prob1 += math_ops.log(1 - self._anomaly_prior_probability) # Anomaly log probability. log_prob2 = self._anomaly_log_prob(targets, prediction_ops) diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index aab330643862c1ccf073d2a0e34e1c475b1ec15f..b7375e5055e29efea3f23c3b9b9f3af59f45495b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -21,6 +21,8 @@ from __future__ import print_function import collections import math +import numpy as np + from tensorflow.contrib import lookup from tensorflow.contrib.layers.python.layers import layers @@ -43,6 +45,32 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +def normal_log_prob(loc, scale, x): + """Computes the Normal log pdf.""" + z = (x - loc) / scale + return -0.5 * (math_ops.square(z) + + np.log(2. * np.pi) + math_ops.log(scale)) + + +def cauchy_log_prob(loc, scale, x): + """Computes the Cauchy log pdf.""" + z = (x - loc) / scale + return (-np.log(np.pi) - math_ops.log(scale) - + math_ops.log1p(math_ops.square(z))) + + +def mvn_tril_log_prob(loc, scale_tril, x): + """Computes the MVN log pdf under tril scale. Doesn't handle batches.""" + x0 = x - loc + z = linalg_ops.matrix_triangular_solve( + scale_tril, x0[..., array_ops.newaxis])[..., 0] + log_det_cov = 2. * math_ops.reduce_sum(math_ops.log( + array_ops.matrix_diag_part(scale_tril)), axis=-1) + d = math_ops.cast(array_ops.shape(scale_tril)[-1], log_det_cov.dtype) + return -0.5 * (math_ops.reduce_sum(math_ops.square(z), axis=-1) + + d * np.log(2. * np.pi) + log_det_cov) + + def clip_covariance( covariance_matrix, maximum_variance_ratio, minimum_variance): """Enforce constraints on a covariance matrix to improve numerical stability. diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 125750e7639ad40c481472a93353e6fb7055be96..cf5e749042afd83f927a3d22edfd3a9538ab2ffd 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -78,7 +78,6 @@ py_library( srcs = ["kalman_filter.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -235,7 +234,6 @@ py_library( srcs = ["filtering_postprocessor.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py index e9e2ac0aaf4c4d6c41f5007662f261af3de9bbd1..3fa2fbd9f77cb887c30fde264815728ca345f45a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py @@ -22,8 +22,6 @@ import abc import six -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -91,10 +89,10 @@ def cauchy_alternative_to_gaussian(current_times, current_values, outputs): """ del current_times # unused cauchy_scale = math_utils.entropy_matched_cauchy_scale(outputs["covariance"]) - individual_log_pdfs = distributions.StudentT( - df=array_ops.ones([], dtype=current_values.dtype), + individual_log_pdfs = math_utils.cauchy_log_prob( loc=outputs["mean"], - scale=cauchy_scale).log_prob(current_values) + scale=cauchy_scale, + x=current_values) return math_ops.reduce_sum(individual_log_pdfs, axis=1) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py index a614386121e000961bf8b32625a28e1251654320..c0ec797bc5b7c41ca996c807840ce38311201f87 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -137,9 +135,10 @@ class KalmanFilter(object): with ops.control_dependencies([non_negative_assert]): observation_covariance_cholesky = linalg_ops.cholesky( symmetrized_observation_covariance) - log_prediction_prob = distributions.MultivariateNormalTriL( - predicted_observation, observation_covariance_cholesky).log_prob( - observation) + log_prediction_prob = math_utils.mvn_tril_log_prob( + loc=predicted_observation, + scale_tril=observation_covariance_cholesky, + x=observation) (posterior_state, posterior_state_var) = self.posterior_from_prior_state( prior_state=estimated_state, diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 3cbf265f68f28d374daeeb44183e8e82a796ac7f..b9c225b44a2f03c8f6d1ef23f75cb5df5667a002 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -70,12 +70,15 @@ py_library( srcs_version = "PY2AND3", deps = [ ":async_checkpoint", + ":functional", ":tpu_lib", + ":tpu_ordinal_selector_py", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:function", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", @@ -155,6 +158,24 @@ tf_gen_op_wrapper_py( ], ) +tf_custom_op_library( + name = "python/ops/_tpu_ordinal_selector.so", + srcs = ["ops/tpu_ordinal_selector_op.cc"], +) + +tf_custom_op_py_library( + name = "tpu_ordinal_selector_py", + srcs = ["ops/gen_tpu_ordinal_selector_op.py"], + dso = [":python/ops/_tpu_ordinal_selector.so"], + kernels = [ + ":tpu_ordinal_selector_op_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":tpu_ordinal_selector_op", + ], +) + tf_gen_op_wrapper_py( name = "tpu_ordinal_selector_op", deps = [ diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index f27ae38e0434991da7475e631be1c6cb4a463118..807cf26fe983b4ebe17695d6f4f90ecfc0e0cbf5 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -33,7 +33,7 @@ setup( long_description='Tools for capture TPU profile', url='https://www.tensorflow.org/tfrc/', author='Google Inc.', - author_email='opensource@google.com', + author_email='packages@tensorflow.org', packages=['cloud_tpu_profiler'], package_data={ 'cloud_tpu_profiler': ['data/*'], diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py index d5d00d628d407bf3bb5312bd54f6ccd13dc37db4..8edf131bc24fd003806263570b63ee8514c49896 100644 --- a/tensorflow/contrib/tpu/python/tpu/feature_column.py +++ b/tensorflow/contrib/tpu/python/tpu/feature_column.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib import math from tensorflow.contrib.tpu.python.tpu import tpu @@ -279,11 +278,10 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): if tpu.under_tpu_inference_context(): - # TODO(shizhiw, b/112012627, b/112336539): Replace _outside_all_rewrites() - # with outside compilation. - with _outside_all_rewrites(): + def host_computation(): return fc._EmbeddingColumn._get_dense_tensor( self, inputs, weight_collections, trainable) + return tpu.outside_compilation(host_computation) if _is_running_on_cpu(): return fc._EmbeddingColumn._get_dense_tensor( @@ -300,13 +298,6 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): return tensor -@contextlib.contextmanager -def _outside_all_rewrites(): - """'Break out' of a tpu.rewrite() (or shard(), etc.).""" - with ops.control_dependencies(None): - yield - - class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._SharedEmbeddingColumn): """Core Shared Embedding Column.""" @@ -385,11 +376,10 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): if tpu.under_tpu_inference_context(): - # TODO(shizhiw, b/112012627, b/112336539): Replace _outside_all_rewrites() - # with outside compilation. - with _outside_all_rewrites(): + def host_computation(): return fc._SharedEmbeddingColumn._get_dense_tensor( self, inputs, weight_collections, trainable) + return tpu.outside_compilation(host_computation) if _is_running_on_cpu(): return fc._SharedEmbeddingColumn._get_dense_tensor( diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 3e463823c820a3ef8628324f77e1a9caf8d385d5..f5735cecc38b7033f21fc4d4105cfead233379fa 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -185,7 +185,8 @@ def all_worker_devices(session): """Return a list of devices for each worker in the system.""" devices = session.list_devices() return [ - device.name for device in devices + device.name + for device in devices if ':CPU:' in device.name and 'coordinator' not in device.name ] @@ -255,12 +256,14 @@ class WatchdogManager(threading.Thread): self._worker_manager.configure( event_pb2.WorkerHeartbeatRequest( watchdog_config=event_pb2.WatchdogConfig( - timeout_ms=self.shutdown_timeout * 1000,))) + timeout_ms=self.shutdown_timeout * 1000,), + shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) def configure_and_run(self): - logging.info('Enabling watchdog timer with %d second timeout ' - 'and %d second ping interval.', - self.shutdown_timeout, self.ping_interval) + logging.info( + 'Enabling watchdog timer with %d second timeout ' + 'and %d second ping interval.', self.shutdown_timeout, + self.ping_interval) self._reset_manager() self._running = True self.start() @@ -269,7 +272,8 @@ class WatchdogManager(threading.Thread): logging.info('Stopping worker watchdog.') self._worker_manager.configure( event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,))) + watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,), + shutdown_mode=event_pb2.NOT_CONFIGURED)) self._running = False self.join() diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 66689cde2fb53c1d9909224fcf38e5e9a7dd729b..bf492e78a15acc92017663a286e8c8f0b2045339 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -59,6 +59,7 @@ _REASON_SCALAR_GET_TRACED = 'traced-scalar' _REASON_TENSOR_GET_TRACED = 'traced-tensor' _REASON_USER_INCLUDED = 'traced-user-included' _REASON_USER_EXCLUDED = 'not-traced-user-excluded' +_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' _REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' _MARKER_SECTION_END = '!!!!!!! section-end:' @@ -377,10 +378,7 @@ class TensorTracer(object): return True # Reasons for not including following op types: # Assign: cause incorrect result with CPU tracing. - # others: compilation problems. - # TODO(deveci): Check if 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax' - # are still unsafe now that we have handled int64 tensors. - if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']: + if op.type in ['Assign']: return True return False @@ -471,7 +469,7 @@ class TensorTracer(object): temporarily_marked_ops, sorted_ops) # pylint: disable=protected-access for ctrl_output_op in op._control_outputs: - # pylint: enable=protected-access + # pylint: enable=protected-access visit(ctrl_output_op, cycle, permanently_marked_ops, temporarily_marked_ops, sorted_ops) temporarily_marked_ops.remove(op) @@ -737,6 +735,59 @@ class TensorTracer(object): self._write_report('%d "%s"\n'%(i, l[i].name)) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) + def _preprocess_traced_tensor(self, tensor): + """Computes NAN/Norm/Max on TPUs before sending to CPU. + + Args: + tensor: The tensor to be traced. + Returns: + A tensor that should be input to the trace_function. + Raises: + RuntimeError: If the trace mode is invalid. + """ + + def _detect_nan_inf(tensor): + """Trace function for detecting any NaN/Inf in the tensor.""" + + if tensor.dtype.is_floating: + output_tensor = math_ops.reduce_any( + gen_math_ops.logical_or( + gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) + else: + output_tensor = constant_op.constant(False) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + def _show_norm(tensor): + tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = linalg_ops.norm(tensor) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + def _show_max_abs(tensor): + tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) + zero = constant_op.constant(0, dtypes.float32) + output_tensor = gen_math_ops.maximum(zero, output_tensor) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + if self._trace_mode == _TRACE_MODE_NAN_INF: + return _detect_nan_inf(tensor) + if self._trace_mode == _TRACE_MODE_PART_TENSOR: + return tensor + if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + return tensor + if self._trace_mode == _TRACE_MODE_NORM: + return _show_norm(tensor) + if self._trace_mode == _TRACE_MODE_MAX_ABS: + return _show_max_abs(tensor) + raise RuntimeError( + 'Tensor trace fun for %s is not yet implemented' % self._trace_mode) + def _make_tensor_trace_fun(self, tensor_name): """Makes the tensor tracing function called by outside compilation. @@ -787,29 +838,6 @@ class TensorTracer(object): with ops.control_dependencies([print_op]): return array_ops.identity(tensor).op - def _detect_nan_inf(tensor): - """Trace function for detecting any NaN/Inf in the tensor.""" - - if tensor.dtype.is_floating: - output_tensor = math_ops.reduce_any( - gen_math_ops.logical_or(gen_math_ops.is_nan(tensor), - gen_math_ops.is_inf(tensor))) - else: - output_tensor = constant_op.constant(False) - - return _print_tensor(tensor_name, -1, tensor, output_tensor) - - def _show_norm(tensor): - tensor = math_ops.cast(tensor, dtypes.float64) - output_tensor = linalg_ops.norm(tensor) - return _print_tensor(tensor_name, -1, tensor, output_tensor) - - def _show_max_abs(tensor): - output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)), - dtypes.float64) - zero = constant_op.constant(0, dtypes.float64) - output_tensor = gen_math_ops.maximum(zero, output_tensor) - return _print_tensor(tensor_name, -1, tensor, output_tensor) def _show_part_tensor(tensor): """Trace function for printing part of the tensor.""" @@ -822,21 +850,23 @@ class TensorTracer(object): return _print_tensor(tensor_name, -1, tensor, tensor) - if self._trace_mode == _TRACE_MODE_NAN_INF: - return _detect_nan_inf if self._trace_mode == _TRACE_MODE_PART_TENSOR: return _show_part_tensor - if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF, + # _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are + # performed within TPUs and only their results are transferred to CPU. + # Simply, print the full tensor for these trace modes. + if self._trace_mode in [ + _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR, + _TRACE_MODE_MAX_ABS + ]: return _show_full_tensor - if self._trace_mode == _TRACE_MODE_NORM: - return _show_norm - if self._trace_mode == _TRACE_MODE_MAX_ABS: - return _show_max_abs raise RuntimeError('Tensor trace fun for %s is not yet implemented' %self._trace_mode) - def _skip_op(self, op_id, op, user_included, user_excluded): + def _skip_op(self, op_id, op, user_included, user_excluded, + in_exec_path=True): """Returns True if we should not trace Op.""" if user_included: @@ -847,6 +877,10 @@ class TensorTracer(object): self._instrument_records[op.name] = TensorTracer.reason( op_id, _REASON_USER_EXCLUDED) return True + if not in_exec_path: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_NOT_EXECUTED) + return True if not self._inside_op_range(op_id): self._instrument_records[op.name] = TensorTracer.reason( op_id, _REASON_OUTSIDE_OP_RANGE) @@ -889,9 +923,18 @@ class TensorTracer(object): op_id, _REASON_USER_EXCLUDED) return True if not out_tensor.get_shape().is_fully_defined(): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_DYNAMIC_SHAPE) - return True + # If trace mode is nan-inf, norm or max, then the tensor will be reduced + # to a scalar before the outside compilation call. + if self._trace_mode in [ + _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS + ]: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_TENSOR_GET_TRACED) + return False + else: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_DYNAMIC_SHAPE) + return True rank = len(out_tensor.shape) if rank < 1: # scalar @@ -909,6 +952,40 @@ class TensorTracer(object): op_id, _REASON_TENSOR_GET_TRACED) return False + def _filter_execution_path_operations(self, operations, fetches): + """Returns the set of ops in the execution path to compute given fetches.""" + # If no fetch provided, then return all operations. + if fetches is None: + return set(operations) + # Convert to list, if a single element is provided. + if not isinstance(fetches, (list, tuple)): + fetches = [fetches] + # If a tensor is given as fetch, convert it to op. + op_fetches = [] + for fetch in fetches: + if isinstance(fetch, ops.Operation): + op_fetches.append(fetch) + elif isinstance(fetch, ops.Tensor): + op_fetches.append(fetch.op) + else: + raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' + %fetch) + + execution_path_operations = set(op_fetches) + traverse_stack = list(op_fetches) + while True: + if not traverse_stack: + break + head_op = traverse_stack.pop() + input_ops = [tensor_input.op for tensor_input in head_op.inputs] + input_ops.extend(head_op.control_inputs) + + for input_op in input_ops: + if input_op not in execution_path_operations: + execution_path_operations.add(input_op) + traverse_stack.append(input_op) + return execution_path_operations + def _pre_tracing(self, graph): """Work needs to be done prior to TPU or CPU tracing.""" @@ -950,13 +1027,15 @@ class TensorTracer(object): _TENSOR_TRACER_CHECKPOINT)) return checkpoint_operations - def trace_tpu(self, graph, result_tensor, num_replicas=None): + def trace_tpu(self, graph, result_tensor, num_replicas=None, fetches=None): """Traces the tensors generated by TPU Ops in a TF graph. Args: graph: the graph of Ops executed on the TPU. result_tensor: a result tensor of evaluating the graph. num_replicas: number of replicas used on the TPU. + fetches: the list of fetches given to session.run, used to determine the + ops in execution path. If None, the whole graph will be traced. Returns: A tuple (result_tensor_copy, tracing_ops), where: @@ -985,6 +1064,10 @@ class TensorTracer(object): result_tensor_copy = self._add_replica_id_to_graph(num_replicas, result_tensor) (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) + # Filter out the operations that won't be executed. + # if fetches=None, then ops_in_exec_path = set(operations) + ops_in_exec_path = self._filter_execution_path_operations(operations, + fetches) tracing_ops = [] checkpoint_operations = self._get_checkpoints(graph) @@ -993,18 +1076,23 @@ class TensorTracer(object): continue user_included = self._is_user_included_op(op) user_excluded = self._is_user_excluded_op(op) - if self._skip_op(op_id, op, user_included, user_excluded): + in_exec_path = op in ops_in_exec_path + if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path): continue for i in range(len(op.outputs)): out_tensor = op.outputs[i] if self._skip_tensor(op_id, out_tensor, user_included, user_excluded): continue + # Create the list of consumers before calling _preprocess_traced_tensor. + # Otherwise, adding control input below, will introduce a cycle in the + # graph. consumers = out_tensor.consumers() tensor_name = out_tensor.name - out_tensor = _cast_unsupported_dtypes(out_tensor) + processed_out_tensor = self._preprocess_traced_tensor(out_tensor) + processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor) trace_op = tpu.outside_compilation( - self._make_tensor_trace_fun(tensor_name), out_tensor) + self._make_tensor_trace_fun(tensor_name), processed_out_tensor) if consumers: for consumer_op in consumers: # pylint: disable=protected-access @@ -1050,8 +1138,9 @@ class TensorTracer(object): if self._skip_tensor(op_id, out_tensor, user_included, user_excluded): continue + processed_out_tensor = self._preprocess_traced_tensor(out_tensor) trace_fun = self._make_tensor_trace_fun(out_tensor.name) - trace_call = (trace_fun, [out_tensor]) + trace_call = (trace_fun, [processed_out_tensor]) trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) tracing_calls[trace_call_key] = trace_call self._post_tracing(succeed, sorted_or_cycle) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 9266d81cf5fc035790062f0e307a5da0b01a9fc1..a267dd435c04c21bbeb3dc09a325267e7b22286e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat +from tensorflow.python.util import nest # Operations that indicate some error in the users graph, e.g. a placeholder @@ -487,7 +488,11 @@ def replicate(computation, computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. + have the same number of inputs. Each input can be a nested structure + containing values that are convertible to tensors. Note that passing an + N-dimension list of compatible values will result in a N-dimention list of + scalar tensors rather than a single Rank-N tensors. If you need different + behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the @@ -526,7 +531,11 @@ def split_compile_and_replicate(computation, computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. + have the same number of inputs. Each input can be a nested structure + containing values that are convertible to tensors. Note that passing an + N-dimension list of compatible values will result in a N-dimention list of + scalar tensors rather than a single Rank-N tensors. If you need different + behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the @@ -580,24 +589,32 @@ def split_compile_and_replicate(computation, if num_replicas == 0: return [] + # Checks all replicas have the same structure. + for i in xrange(1, num_replicas): + nest.assert_same_structure(inputs[0], inputs[i]) + + # Flatten inputs. + flat_inputs = [ + nest.flatten(per_replica_input) for per_replica_input in inputs + ] # Converts inputs to Tensors. - inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs] + flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs] # Verifies that all replicas have matching numbers and types of inputs - input_types = [x.dtype for x in inputs[0]] - input_arity = len(input_types) + flat_input_types = [x.dtype for x in flat_inputs[0]] + input_arity = len(inputs[0]) + flat_input_arity = len(flat_input_types) for i in range(num_replicas): if len(inputs[i]) != input_arity: raise ValueError("Replicas must have the same number of inputs. " "Replica 0 had {} inputs, replica {} had {} " "inputs.".format(input_arity, i, len(inputs[i]))) - types = [x.dtype for x in inputs[i]] - if types != input_types: - raise ValueError( - "Replicas must have matching input types. Replica 0 had " - "input types {}, replica {} had input types {}".format( - input_types, i, types)) + types = [x.dtype for x in flat_inputs[i]] + if types != flat_input_types: + raise ValueError("Replicas must have matching input types. Replica 0 had " + "input types {}, replica {} had input types {}".format( + flat_input_types, i, types)) arg_error = xla.check_function_argument_count( computation, input_arity, infeed_queue) @@ -620,8 +637,8 @@ def split_compile_and_replicate(computation, # Fan-in: Builds a TPUReplicatedInput node for each input. computation_inputs = [] - for i in range(0, input_arity): - replicas = [inputs[replica][i] for replica in xrange(num_replicas)] + for i in range(0, flat_input_arity): + replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) @@ -651,6 +668,10 @@ def split_compile_and_replicate(computation, i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True)) # pylint: enable=protected-access + # Unflatten the computation inputs to match original input structure. + computation_inputs = nest.pack_sequence_as( + structure=inputs[0], flat_sequence=computation_inputs) + # If there is an infeed queue, adds the dequeued values to the # computation's inputs. if infeed_queue is not None: @@ -1092,6 +1113,11 @@ def rewrite(computation, All `Operation`s constructed during `computation` will be executed when evaluating any of the returned output tensors, not just the ones returned. inputs: A list of input tensors or `None` (equivalent to an empty list). + Each input can be a nested structure containing values that are + convertible to tensors. Note that passing an N-dimension list of + compatible values will result in a N-dimention list of scalar tensors + rather than a single Rank-N tensors. If you need different behavior, + convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. device_assignment: if not `None`, a `DeviceAssignment` describing the diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 87a970f0523363426b0da5b12838b797d7f8bebb..075ecbc52a642ebe9c35e0b9ede8bd4ee963aec6 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,11 +31,13 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.ops import gen_tpu_ordinal_selector_op from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result -from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling +from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional from tensorflow.contrib.tpu.python.tpu import session_support +from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_context @@ -55,6 +57,7 @@ from tensorflow.python.estimator.export import export_output as export_output_li from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -90,6 +93,7 @@ _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor' # Ideally _USE_TPU_KEY should be reserved as well. However there are already # models that make use of this key, thus it can not be reserved now to prevent @@ -1303,6 +1307,44 @@ class _InputPipeline(object): logging.warn(err_msg) +def call_computation(computation, + experimental_exported_model_uses_all_cores=True): + """Call computation. + + computation uses a single-core for TPU inference. If + `experimental_exported_model_uses_all_cores` is `True`, this function will + round-robin + computation among all TPU cores visible to the host; otherwise, it will use + a single core. + + Args: + computation: A Python function that takes no inputs and builds computation + graph. If `computation` returns m outputs, this function will return a + list of m Tensors. + experimental_exported_model_uses_all_cores: Whether to round-robin among all + cores visible to the host, or to use a single core. + + Returns: + A list of output tensors. + """ + if experimental_exported_model_uses_all_cores: + # Using `TPUPartitionedCall` makes it possible to target a different + # TPU core with every `Session.run()` call. Note that the entire inference + # graph executes on a single core, and that invocations of this graph + # will round-robin among the cores attached to a host. + @function.Defun() + def tpu_subgraph(): + return computation() + + return tpu_functional.TPUPartitionedCall( + args=tpu_subgraph.captured_inputs, + device_ordinal=gen_tpu_ordinal_selector_op.tpu_ordinal_selector(), + Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg], + f=tpu_subgraph) + else: + return computation() + + class _ModelFnWrapper(object): """A `model_fn` wrapper. @@ -1370,7 +1412,8 @@ class _ModelFnWrapper(object): if tensor_tracer.TensorTracer.is_enabled(): tt = tensor_tracer.TensorTracer() loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, - self._ctx.num_replicas) + self._ctx.num_replicas, + fetches=[loss, train_op]) # We must run train_op to update the variables prior to running the # outfeed. @@ -2100,7 +2143,8 @@ class TPUEstimator(estimator_lib.Estimator): batch_axis=None, eval_on_tpu=True, export_to_tpu=True, - warm_start_from=None): + warm_start_from=None, + experimental_exported_model_uses_all_cores=False): """Constructs an `TPUEstimator` instance. Args: @@ -2143,12 +2187,20 @@ class TPUEstimator(estimator_lib.Estimator): eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. export_to_tpu: If True, `export_savedmodel()` exports a metagraph for - serving on TPU besides the one on CPU. + serving on TPU besides the one on CPU. Note that unsupported export + modes such as EVAL will be ignored. For those modes, only a CPU model + will be exported. Currently, export_to_tpu only supports PREDICT. warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string filepath is provided instead of a `WarmStartSettings`, then all variables are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + experimental_exported_model_uses_all_cores: Whether to round-robin among + all cores visible to the host which is serving the saved model, or to + use a single core. This is a temporary flag to enable using all TPU + cores for inference with TPUPartitionedCall(). Once outside compilation + is supported in TPUPartitionedCall(), this flag will be enabled by + default. Raises: ValueError: `params` has reserved keys already. @@ -2213,6 +2265,8 @@ class TPUEstimator(estimator_lib.Estimator): use_tpu, eval_on_tpu) self._export_to_tpu = export_to_tpu + self._experimental_exported_model_uses_all_cores = ( + experimental_exported_model_uses_all_cores) self._is_input_fn_invoked = None self._rendezvous = {} @@ -2226,10 +2280,9 @@ class TPUEstimator(estimator_lib.Estimator): export_tags=None, check_variables=True): if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: - raise NotImplementedError( - 'TPUEstimator only handles mode PREDICT for exporting ' - 'when `export_to_tpu` is `True`; ' - 'got {}.'.format(mode)) + logging.warning('TPUEstimator only handles mode PREDICT for exporting ' + 'when `export_to_tpu` is `True`; Mode {} will be ignored ' + 'for TPU.'.format(mode)) (super(TPUEstimator, self)._add_meta_graph_for_mode( builder, @@ -2240,7 +2293,7 @@ class TPUEstimator(estimator_lib.Estimator): export_tags=export_tags, check_variables=check_variables)) - if self._export_to_tpu: + if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT: input_receiver_fn_map = { _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode] } @@ -2269,6 +2322,79 @@ class TPUEstimator(estimator_lib.Estimator): raise ValueError('mode must be {}; ' 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) + computation, capture = self._build_computation_for_inference( + features, labels, mode, config) + tensors = call_computation( + computation, + experimental_exported_model_uses_all_cores=self + ._experimental_exported_model_uses_all_cores) + estimator_spec, export_outputs_dict, predictions_dict, none_indices = ( + capture.get()) + predictions_list = tensors[:len(predictions_dict)] + export_outputs_list_without_none = tensors[len(predictions_dict):] + + # Reinsert `None`s which we've taken out in + # `_build_computation_for_inference()`. + export_outputs_list = [] + while none_indices or export_outputs_list_without_none: + if none_indices and none_indices[0] == len(export_outputs_list): + export_outputs_list.append(None) + none_indices.pop(0) + else: + export_outputs_list.append(export_outputs_list_without_none.pop(0)) + + # Reconstruct `export_outputs` with updated tensors. + new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict, + export_outputs_list) + export_outputs = estimator_spec.export_outputs + new_export_outputs = collections.OrderedDict( + (k, _clone_export_output_with_tensors(export_outputs[k], v)) + for k, v in six.iteritems(new_export_outputs_dict)) + # Reconstruct `predictions` with updated tensors. + new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list) + if (len(new_predictions) == 1 and + _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions): + new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR] + + return estimator_spec._replace( + export_outputs=new_export_outputs, predictions=new_predictions) + + def _build_computation_for_inference(self, features, labels, mode, config): + capture = _CapturedObject() + + def computation(): + """Computation to be passed to `TPUPartitionedCall()`.""" + tpu_computation, tpu_capture = self._build_tpu_computation_for_inference( + features, labels, mode, config) + + tensors_on_cpu = tpu.rewrite_for_inference(tpu_computation) + (estimator_spec, export_outputs_dict, export_outputs_list, + predictions_dict) = ( + tpu_capture.get()) + predictions_list = tensors_on_cpu[:len(predictions_dict)] + export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):] + + # Reconstruct tensors used in export_outputs, with TPU tensors replaced + # with their CPU counterpart returned from `rewrite_for_inference()`. + # `function.Defun()` does not like `None`s in return values, so we leave + # `None`s out but record their positions for later reconstruction. + export_outputs_list_without_none = [] + none_indices = [] + for i, t in enumerate(export_outputs_list): + if t is None: + none_indices.append(i) + else: + export_outputs_list_without_none.append( + export_outputs_tpu_on_cpu_list.pop(0)) + + capture.capture((estimator_spec, export_outputs_dict, predictions_dict, + none_indices)) + return predictions_list + export_outputs_list_without_none + + return computation, capture + + def _build_tpu_computation_for_inference(self, features, labels, mode, + config): capture = _CapturedObject() def computation(): @@ -2289,38 +2415,30 @@ class TPUEstimator(estimator_lib.Estimator): # We pick the TPU tensors out from `export_output` and later return them # from `computation` for rewriting. - tensors_dict = collections.OrderedDict( + export_outputs_dict = collections.OrderedDict( (k, _export_output_to_tensors(v)) for k, v in six.iteritems(estimator_spec.export_outputs)) - tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if t is not None] - - # We cannot return anything other than `tpu_tensors` here so we capture - # the rest for later use. - capture.capture((estimator_spec, tensors_dict, tensors)) - return tpu_tensors - - tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) - estimator_spec, tensors_dict, tensors = capture.get() - - # Reconstruct `tensors`, but with `tpu_tensors` replaced with - # `tpu_tensors_on_cpu`. - new_tensors = [] - for t in tensors: - if t is None: - new_tensors.append(None) + export_outputs_list = nest.flatten(export_outputs_dict) + export_outputs_tpu_list = [ + t for t in export_outputs_list if t is not None + ] + + if isinstance(estimator_spec.predictions, dict): + predictions_dict = collections.OrderedDict( + (k, v) for k, v in six.iteritems(estimator_spec.predictions)) else: - new_tensors.append(tpu_tensors_on_cpu.pop(0)) + predictions_dict = { + _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions + } + predictions_list = nest.flatten(predictions_dict) - # Reconstruct `tensors_dict`. - new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) - # Reconstruct `export_outputs`. - export_outputs = estimator_spec.export_outputs - new_export_outputs = collections.OrderedDict( - (k, _clone_export_output_with_tensors(export_outputs[k], v)) - for k, v in six.iteritems(new_tensors_dict)) + # We cannot return everything we want through the return values, so + # capture the rest here for later use. + capture.capture((estimator_spec, export_outputs_dict, export_outputs_list, + predictions_dict)) + return predictions_list + export_outputs_tpu_list - return estimator_spec._replace(export_outputs=new_export_outputs) + return computation, capture def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -2538,7 +2656,7 @@ class TPUEstimator(estimator_lib.Estimator): if self._log_every_n_steps is not None: examples_hook = ExamplesPerSecondHook( ctx.global_batch_size, - output_dir=self.model_dir, + output_dir=self.model_dir if config.save_summary_steps else None, every_n_steps=self._log_every_n_steps) if ctx.is_running_on_cpu(is_export_mode=is_export_mode): diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index bcc177601b95172b05d327247bd370c2f8b65d59..27f0d9b2e38c433d4fb4573285ecb8c9946112e8 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -499,6 +499,7 @@ class HParams(object): value: New value of the hyperparameter. Raises: + KeyError: If the hyperparameter doesn't exist. ValueError: If there is a type mismatch. """ param_type, is_list = self._hparam_types[name] @@ -517,6 +518,8 @@ class HParams(object): def del_hparam(self, name): """Removes the hyperparameter with key 'name'. + Does nothing if it isn't present. + Args: name: Name of the hyperparameter. """ @@ -525,19 +528,20 @@ class HParams(object): del self._hparam_types[name] def parse(self, values): - """Override hyperparameter values, parsing new values from a string. + """Override existing hyperparameter values, parsing new values from a string. See parse_values for more detail on the allowed format for values. Args: - values: String. Comma separated list of `name=value` pairs where - 'value' must follow the syntax described above. + values: String. Comma separated list of `name=value` pairs where 'value' + must follow the syntax described above. Returns: The `HParams` instance. Raises: - ValueError: If `values` cannot be parsed. + ValueError: If `values` cannot be parsed or a hyperparameter in `values` + doesn't exist. """ type_map = dict() for name, t in self._hparam_types.items(): @@ -548,7 +552,7 @@ class HParams(object): return self.override_from_dict(values_map) def override_from_dict(self, values_dict): - """Override hyperparameter values, parsing new values from a dictionary. + """Override existing hyperparameter values, parsing new values from a dictionary. Args: values_dict: Dictionary of name:value pairs. @@ -557,6 +561,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_dict` doesn't exist. ValueError: If `values_dict` cannot be parsed. """ for name, value in values_dict.items(): @@ -596,7 +601,7 @@ class HParams(object): sort_keys=sort_keys) def parse_json(self, values_json): - """Override hyperparameter values, parsing new values from a json object. + """Override existing hyperparameter values, parsing new values from a json object. Args: values_json: String containing a json object of name:value pairs. @@ -605,6 +610,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_json` doesn't exist. ValueError: If `values_json` cannot be parsed. """ values_map = json.loads(values_json) diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index c272a2ac144068cfb7355c2647eebf5bd0ce9d50..fc6e38ab4a5243cb7502f4ca42db03cbfd342a40 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -419,7 +419,7 @@ def create_train_op(total_loss, update_ops = set(update_ops) if not global_update_ops.issubset(update_ops): logging.warning('update_ops in create_train_op does not contain all the ' - ' update_ops in GraphKeys.UPDATE_OPS') + 'update_ops in GraphKeys.UPDATE_OPS') # Make sure update_ops are computed before total_loss. if update_ops: diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 2abaadc657858c800db61d08af2d47e8324a546b..e6af9211b55097303e3fefba1ceb7b9e53a37d90 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1371,7 +1371,7 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. cc_library( - name = "all_kernels_statically_linked", + name = "all_kernels_impl", visibility = ["//visibility:private"], deps = [ "//tensorflow/core/kernels:array", @@ -1387,7 +1387,6 @@ cc_library( "//tensorflow/core/kernels:ctc_ops", "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/kernels:data_flow", - "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:decode_proto_op", "//tensorflow/core/kernels:encode_proto_op", "//tensorflow/core/kernels:fake_quant_ops", @@ -1398,7 +1397,6 @@ cc_library( "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", - "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:manip", @@ -1461,8 +1459,13 @@ cc_library( visibility = ["//visibility:public"], deps = if_dynamic_kernels( [], - otherwise = [":all_kernels_statically_linked"], - ), + otherwise = [":all_kernels_impl"], + ) + [ + # TODO(gunan): Work on the API between these and rest of TF and make + # these also dynamically loading. + "//tensorflow/core/kernels:dataset_ops", # Depends on grappler + "//tensorflow/core/kernels:list_kernels", # Depends on variant_op_registry.h + ], ) tf_cuda_library( @@ -1747,6 +1750,7 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], @@ -2006,6 +2010,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/step_stats_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/step_stats.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/types_pyclif", proto_lib = ":protos_all_cc", @@ -2183,6 +2194,7 @@ cc_library( ], }), deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "//third_party/eigen3", "@com_google_absl//absl/base:core_headers", @@ -2337,7 +2349,12 @@ cc_library( cc_library( name = "tflite_portable_logging", - srcs = [], + srcs = [ + ] + if_ios([ + "platform/default/logging.cc", + "platform/env_time.cc", + "platform/posix/env_time.cc", + ]), hdrs = [ "lib/bfloat16/bfloat16.h", "platform/default/integral_types.h", @@ -2346,7 +2363,7 @@ cc_library( "platform/macros.h", "platform/platform.h", "platform/types.h", - ] + if_windows(["platform/windows/integral_types.h"]), + ] + if_windows(["platform/windows/integral_types.h"]) + if_ios(["platform/env_time.h"]), copts = tf_copts(), linkopts = ["-ldl"], deps = [ @@ -2801,6 +2818,7 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -2815,12 +2833,16 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [ "framework/versions.h", "common_runtime/process_function_library_runtime.h", "common_runtime/function.h", + "common_runtime/scoped_allocator.h", + "common_runtime/scoped_allocator_mgr.h", ] tf_cuda_library( name = "core_cpu_base", srcs = [ "common_runtime/eval_const_tensor.cc", + "common_runtime/scoped_allocator.cc", + "common_runtime/scoped_allocator_mgr.cc", "common_runtime/shape_refiner.cc", "common_runtime/shape_refiner.h", "framework/versions.h", @@ -2880,6 +2902,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", + "common_runtime/partitioning_utils.h", "common_runtime/placer.h", "common_runtime/process_util.h", "common_runtime/profile_handler.h", @@ -2887,8 +2910,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", - "common_runtime/scoped_allocator.h", - "common_runtime/scoped_allocator_mgr.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -2936,6 +2957,7 @@ tf_cuda_library( "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", + "common_runtime/partitioning_utils.cc", "common_runtime/placer.cc", "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", @@ -2945,8 +2967,6 @@ tf_cuda_library( "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", "common_runtime/ring_reducer.cc", - "common_runtime/scoped_allocator.cc", - "common_runtime/scoped_allocator_mgr.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", "common_runtime/session_options.cc", @@ -2974,8 +2994,9 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//third_party/eigen3", - "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:functions", ] + mkl_deps(), alwayslink = 1, ) @@ -3490,6 +3511,29 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_fake_python_env_test", + size = "small", + srcs = ["platform/fake_python_env_test.cc"], + args = [ + "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", + ], + tags = [ + "local", + "no_windows", + "nogpu", + "nomac", + "notap", + ], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "platform_abi_test", size = "small", @@ -4198,7 +4242,7 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "common_runtime_process_function_library_runtime_test", size = "small", srcs = ["common_runtime/process_function_library_runtime_test.cc"], @@ -4207,6 +4251,7 @@ tf_cc_test( ":core_cpu", ":core_cpu_internal", ":framework", + ":framework_internal", ":lib", ":test", ":test_main", @@ -4215,6 +4260,7 @@ tf_cc_test( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:resource_variable_ops", ], ) @@ -4256,6 +4302,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_partitioning_utils_test", + size = "small", + srcs = ["common_runtime/partitioning_utils_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":lib", + ":ops", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", + ], +) + tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt index 070d6adb978e4a62e7209f299dba08515aa21e83..d0794de4ba4a174838547865e4f1692cff503052 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt @@ -33,6 +33,15 @@ END name: "padding" description: <