diff --git a/CODEOWNERS b/CODEOWNERS index 54a61a4d72c40d297d90d53e223f64f813d9167d..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,7 +1,7 @@ # Where component owners are known, add them here. /tenosrflow/core/debug @caisq -/tensorflow/core/nccl/ @azaks @csigg +/tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mrry /tensorflow/core/platform/s3 @yongtang /tensorflow/go @asimshankar @@ -51,13 +51,13 @@ /tensorflow/contrib/pi_examples/ @maciekcc /tensorflow/contrib/quantization/ @petewarden /tensorflow/contrib/rnn/ @ebrevdo @scottzhu -/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie /tensorflow/contrib/seq2seq/ @ebrevdo @lmthang /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh /tensorflow/contrib/slim/ @sguada @thenbasilmanran /tensorflow/contrib/stateless/ @girving @alextp /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -/tensorflow/contrib/tensorrt/ @aaroey +/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2 # NEED OWNER: /tensorflow/contrib/testing/ /tensorflow/contrib/timeseries/ @allenlavoie /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj diff --git a/configure.py b/configure.py index 3a7999fbd65cde41f8fa9c61ef65130db01dae01..6c905a0be3d685b5921dfbc5bddfbe6471a82625 100644 --- a/configure.py +++ b/configure.py @@ -1565,7 +1565,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.15.0', '0.19.2') + check_bazel_version('0.15.0', '0.20.0') reset_tf_configure_bazelrc() # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index f653e581bf3beda9fdbf8fb7905a4f9fe170e7fb..25df970ecab0757f23465ab19e7f45de0c759458 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -175,6 +175,34 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "env", + srcs = [ + "env.cc", + ], + hdrs = [ + "env.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + "//conditions:default": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + }) + [":c_api_internal"], +) + tf_cuda_library( name = "kernels", srcs = [ @@ -188,10 +216,14 @@ tf_cuda_library( deps = select({ "//tensorflow:android": [ ":c_api", + ":c_api_internal", + ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api", + ":c_api_internal", + ":tf_status_helper", "//tensorflow/core:framework", ], }), @@ -330,6 +362,27 @@ tf_kernel_library( alwayslink = 1, ) +tf_cuda_cc_test( + name = "env_test", + size = "small", + srcs = ["env_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_cc_test( name = "kernels_test", size = "small", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index f13e8777dff164bcd8eedf46310ae846abd0c804..94d18eb8b04e3534be547aca5cfbb32da40ffbf6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -136,16 +136,22 @@ const char* TF_Message(const TF_Status* s) { namespace { class TF_ManagedBuffer : public TensorBuffer { public: - void* data_; - size_t len_; - void (*deallocator_)(void* data, size_t len, void* arg); - void* deallocator_arg_; + TF_ManagedBuffer(void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg) + : TensorBuffer(data), + len_(len), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg) {} + + const size_t len_; + void (*const deallocator_)(void* data, size_t len, void* arg); + void* const deallocator_arg_; ~TF_ManagedBuffer() override { - (*deallocator_)(data_, len_, deallocator_arg_); + (*deallocator_)(data(), len_, deallocator_arg_); } - void* data() const override { return data_; } size_t size() const override { return len_; } TensorBuffer* root_buffer() override { return this; } void FillAllocationDescription(AllocationDescription* proto) const override { @@ -199,8 +205,7 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, dimvec[i] = static_cast(dims[i]); } - TF_ManagedBuffer* buf = new TF_ManagedBuffer; - buf->len_ = len; + TF_ManagedBuffer* buf = nullptr; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != @@ -212,17 +217,15 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, // // Other types have the same representation, so copy only if it is safe to // do so. - buf->data_ = allocate_tensor("TF_NewTensor", len); - std::memcpy(buf->data_, data, len); - buf->deallocator_ = deallocate_buffer; - buf->deallocator_arg_ = nullptr; + buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len, + deallocate_buffer, nullptr); + std::memcpy(buf->data(), data, len); // Free the original buffer. deallocator(data, len, deallocator_arg); } else { - buf->data_ = data; - buf->deallocator_ = deallocator; - buf->deallocator_arg_ = deallocator_arg; + buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); } + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; size_t elem_size = TF_DataTypeSize(dtype); if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { @@ -477,9 +480,9 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { CHECK_EQ(nelems, 0); static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), "64-bit int types should match in size"); - return TF_NewTensor(dtype, reinterpret_cast(dims.data()), - shape.dims(), reinterpret_cast(&empty), 0, - [](void*, size_t, void*) {}, nullptr); + return TF_NewTensor( + dtype, reinterpret_cast(dims.data()), shape.dims(), + reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); } // Non-static for testing. @@ -1592,18 +1595,20 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, break; \ } - LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { metadata.total_size += attr->list().s(i).size(); }); + LIST_CASE( + s, TF_ATTR_STRING, metadata.total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { metadata.total_size += attr->list().s(i).size(); }); LIST_CASE(i, TF_ATTR_INT); LIST_CASE(f, TF_ATTR_FLOAT); LIST_CASE(b, TF_ATTR_BOOL); LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); + LIST_CASE( + shape, TF_ATTR_SHAPE, metadata.total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); + }); LIST_CASE(tensor, TF_ATTR_TENSOR); LIST_CASE(tensor, TF_ATTR_FUNC); #undef LIST_CASE diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 3d56268110edbe96616201d15a69cc8c84d3115a..c7abba85521fccec07983cd5ab4f94a8368d6181 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -91,7 +91,7 @@ extern "C" { // -------------------------------------------------------------------------- // TF_Version returns a string describing version information of the // TensorFlow library. TensorFlow using semantic versioning. -TF_CAPI_EXPORT extern const char* TF_Version(); +TF_CAPI_EXPORT extern const char* TF_Version(void); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. @@ -157,7 +157,7 @@ typedef enum TF_Code { typedef struct TF_Status TF_Status; // Return a new status object. -TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); // Delete a previously created status object. TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); @@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); // Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); @@ -305,7 +305,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. @@ -338,7 +338,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. @@ -890,7 +890,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); @@ -1611,7 +1612,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); // TF_ApiDefMap encapsulates a collection of API definitions for an operation. // diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 80c8bfe594c4c89606efd01bec7f50e7a86b5bda..3e3a485eb763b871b0551414c4ef04746b2ed9a3 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -239,7 +239,7 @@ TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); // Platform-specific implementation to return an unused port. (This should used // in tests only.) -TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(); +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); // Fast path method that makes constructing a single scalar tensor require less // overhead and copies. diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 8d6c8d958d5961fce817156a14eb2b2940c1f2f0..f80ae5a6d02d4d613c95cf8486e0fc0aeed3affc 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -48,7 +48,7 @@ extern "C" { typedef struct TFE_ContextOptions TFE_ContextOptions; // Return a new options object. -TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void); // Set the config in TF_ContextOptions.options. // config should be a serialized tensorflow.ConfigProto proto. diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065 --- /dev/null +++ b/tensorflow/c/env.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +struct TF_StringStream { + std::vector<::tensorflow::string>* list; + size_t position; +}; + +void TF_CreateDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->CreateDir(dirname)); +} + +void TF_DeleteDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteDir(dirname)); +} + +void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, TF_Status* status) { + ::tensorflow::int64 f, d; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteRecursively(dirname, &f, &d)); + *undeleted_file_count = f; + *undeleted_dir_count = d; +} + +void TF_FileStat(const char* filename, TF_FileStatistics* stats, + TF_Status* status) { + ::tensorflow::FileStatistics cc_stats; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->Stat(filename, &cc_stats); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + stats->length = cc_stats.length; + stats->mtime_nsec = cc_stats.mtime_nsec; + stats->is_directory = cc_stats.is_directory; + } +} + +void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle, + TF_Status* status) { + std::unique_ptr<::tensorflow::WritableFile> f; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->NewWritableFile(filename, &f); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (s.ok()) { + *handle = reinterpret_cast(f.release()); + } +} + +void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Close()); + delete cc_file; +} + +void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Sync()); +} + +void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Flush()); +} + +void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data, + size_t length, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, cc_file->Append(::tensorflow::StringPiece{data, length})); +} + +void TF_DeleteFile(const char* filename, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteFile(filename)); +} + +bool TF_StringStreamNext(TF_StringStream* list, const char** result) { + if (list->position >= list->list->size()) { + *result = nullptr; + return false; + } + + *result = list->list->at(list->position++).c_str(); + return true; +} + +void TF_StringStreamDone(TF_StringStream* list) { + delete list->list; + delete list; +} +TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { + auto* children = new std::vector<::tensorflow::string>; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->GetChildren(dirname, children)); + + auto* list = new TF_StringStream; + list->list = children; + list->position = 0; + return list; +} + +TF_StringStream* TF_GetLocalTempDirectories() { + auto* tmpdirs = new std::vector<::tensorflow::string>; + + ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs); + + auto* list = new TF_StringStream; + list->list = tmpdirs; + list->position = 0; + return list; +} + +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) { + return ::tensorflow::Env::Default()->NowNanos(); +} + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { + return ::tensorflow::Env::Default()->NowMicros(); +} + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { + return ::tensorflow::Env::Default()->NowSeconds(); +} + +void TF_DefaultThreadOptions(TF_ThreadOptions* options) { + options->stack_size = 0; + options->guard_size = 0; + options->numa_node = -1; +} + +TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, void (*work_func)(void*), + void* param) { + ::tensorflow::ThreadOptions cc_options; + cc_options.stack_size = options->stack_size; + cc_options.guard_size = options->guard_size; + cc_options.numa_node = options->numa_node; + return reinterpret_cast(::tensorflow::Env::Default()->StartThread( + cc_options, thread_name, [=]() { (*work_func)(param); })); +} + +void TF_JoinThread(TF_Thread* thread) { + // ::tensorflow::Thread joins on destruction + delete reinterpret_cast<::tensorflow::Thread*>(thread); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h new file mode 100644 index 0000000000000000000000000000000000000000..15652353cd7e1f1e7d7a4c665703c0166682d790 --- /dev/null +++ b/tensorflow/c/env.h @@ -0,0 +1,194 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#ifndef TENSORFLOW_C_ENV_H_ +#define TENSORFLOW_C_ENV_H_ + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Env. + +struct TF_WritableFileHandle; +struct TF_StringStream; +struct TF_Thread; + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_FileStatistics { + // The length of the file in bytes. + int64_t length; + // The last modified time in nanoseconds. + int64_t mtime_nsec; + // Whether the name refers to a directory. + bool is_directory; +} TF_FileStatistics; + +typedef struct TF_ThreadOptions { + // Thread stack size to use (in bytes), zero implies that the system default + // will be used. + size_t stack_size; + + // Guard area size to use near thread stacks to use (in bytes), zero implies + // that the system default will be used. + size_t guard_size; + + // The NUMA node to use, -1 implies that there should be no NUMA affinity for + // this thread. + int numa_node; +} TF_ThreadOptions; + +// Creates the specified directory. Typical status code are: +// * TF_OK - successfully created the directory +// * TF_ALREADY_EXISTS - directory already exists +// * TF_PERMISSION_DENIED - dirname is not writable +TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory. Typical status codes are: +// * TF_OK - successfully deleted the directory +// * TF_FAILED_PRECONDITION - the directory is not empty +TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory and all subdirectories and files underneath +// it. This is accomplished by traversing the directory tree rooted at dirname +// and deleting entries as they are encountered. +// +// If dirname itself is not readable or does not exist, *undeleted_dir_count is +// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g. +// TF_NOT_FOUND) is returned. +// +// If dirname and all its descendants were successfully deleted, TF_OK is +// returned and both error counters are set to zero. +// +// Otherwise, while traversing the tree, undeleted_file_count and +// undeleted_dir_count are updated if an entry of the corresponding type could +// not be deleted. The returned error status represents the reason that any one +// of these entries could not be deleted. +// +// Typical status codes: +// * TF_OK - dirname exists and we were able to delete everything underneath +// * TF_NOT_FOUND - dirname doesn't exist +// * TF_PERMISSION_DENIED - dirname or some descendant is not writable +// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not +// implemented +TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname, + int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, + TF_Status* status); + +// Obtains statistics for the given path. If status is TF_OK, *stats is +// updated, otherwise it is not touched. +TF_CAPI_EXPORT extern void TF_FileStat(const char* filename, + TF_FileStatistics* stats, + TF_Status* status); + +// Creates or truncates the given filename and returns a handle to be used for +// appending data to the file. If status is TF_OK, *handle is updated and the +// caller is responsible for freeing it (see TF_CloseWritableFile). +TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename, + TF_WritableFileHandle** handle, + TF_Status* status); + +// Closes the given handle and frees its memory. If there was a problem closing +// the file, it is indicated by status. Memory is freed in any case. +TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Syncs content of the handle to the filesystem. Blocks waiting for the +// filesystem to indicate that the content has been persisted. +TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Flush local buffers to the filesystem. If the process terminates after a +// successful flush, the contents may still be persisted, since the underlying +// filesystem may eventually flush the contents. If the OS or machine crashes +// after a successful flush, the contents may or may not be persisted, depending +// on the implementation. +TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Appends the given bytes to the file. Any failure to do so is indicated in +// status. +TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle, + const char* data, + size_t length, + TF_Status* status); + +// Deletes the named file and indicates whether successful in *status. +TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename, + TF_Status* status); + +// Retrieves the next item from the given TF_StringStream and places a pointer +// to it in *result. If no more items are in the list, *result is set to NULL +// and false is returned. +// +// Ownership of the items retrieved with this function remains with the library. +// Item points are invalidated after a call to TF_StringStreamDone. +TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list, + const char** result); + +// Frees the resources associated with given string list. All pointers returned +// by TF_StringStreamNext are invalid after this call. +TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list); + +// Retrieves the list of children of the given directory. You can iterate +// through the list with TF_StringStreamNext. The caller is responsible for +// freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, + TF_Status* status); + +// Retrieves a list of directory names on the local machine that may be used for +// temporary storage. You can iterate through the list with TF_StringStreamNext. +// The caller is responsible for freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); + +// Returns the number of nanoseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); + +// Populates a TF_ThreadOptions struct with system-default values. +TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options); + +// Returns a new thread that is running work_func and is identified +// (for debugging/performance-analysis) by thread_name. +// +// The given param (which may be null) is passed to work_func when the thread +// starts. In this way, data may be passed from the thread back to the caller. +// +// Caller takes ownership of the result and must call TF_JoinThread on it +// eventually. +TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, + void (*work_func)(void*), + void* param); + +// Waits for the given thread to finish execution, then deletes it. +TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_ENV_H_ diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..687ad024137352662759ec1f43df87e89faca353 --- /dev/null +++ b/tensorflow/c/env_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) + +TEST(TestEnv, TestDirHandling) { + TF_StringStream* tempdirs = TF_GetLocalTempDirectories(); + const char* tempdir; + bool found = false; + while (TF_StringStreamNext(tempdirs, &tempdir)) { + found = true; + + TF_Status* s = TF_NewStatus(); + + ::tensorflow::string dirpath = + ::tensorflow::io::JoinPath(tempdir, "somedir"); + TF_CreateDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": " + << TF_Message(s); + + ::tensorflow::string filepath = + ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); + TF_WritableFileHandle* handle; + TF_NewWritableFile(filepath.c_str(), &handle, s); + ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": " + << TF_Message(s); + + const char* data = "Hello, world!\n"; + TF_AppendWritableFile(handle, data, strlen(data), s); + ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at " + << filepath << ": " << TF_Message(s); + + TF_CloseWritableFile(handle, s); + ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to " + << filepath << ": " << TF_Message(s); + + TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath; + const char* childpath; + ASSERT_TRUE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt"); + // There should only be one file in this directory. + ASSERT_FALSE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(childpath, nullptr); + TF_StringStreamDone(children); + + TF_FileStatistics stats; + TF_FileStat(filepath.c_str(), &stats, s); + ASSERT_EQ(stats.length, strlen(data)); + ASSERT_FALSE(stats.is_directory); + ASSERT_GT(stats.mtime_nsec, 0); + + // Trying to delete a non-empty directory should fail. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_NE(TF_OK, TF_GetCode(s)) + << "TF_DeleteDir unexpectedly succeeded with a non-empty directory " + << dirpath; + + TF_DeleteFile(filepath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": " + << TF_Message(s); + + // Now deleting the directory should work. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": " + << TF_Message(s); + + TF_DeleteStatus(s); + break; + } + + ASSERT_TRUE(found) << "expected at least one temp dir"; + + TF_StringStreamDone(tempdirs); +} + +TEST(TestEnv, TestTimeFunctions) { + ASSERT_GE(TF_NowSeconds(), 946684800); // Midnight Jan 1, 2000 + ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); + ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); +} + +namespace { + +struct SomeThreadData { + ::tensorflow::mutex mu; + bool did_work = false; +}; + +void SomeThreadFunc(void* data) { + auto* real_data = static_cast(data); + ::tensorflow::mutex_lock l(real_data->mu); + real_data->did_work = true; +} + +} // namespace + +TEST(TestEnv, TestThreads) { + TF_ThreadOptions options; + TF_DefaultThreadOptions(&options); + SomeThreadData data; + TF_Thread* thread = + TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data); + TF_JoinThread(thread); + ::tensorflow::mutex_lock l(data.mu); + ASSERT_TRUE(data.did_work); +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index ca69345264607ac689fb556b4f5c9bc08ea5eb88..2a4eaecb6cf2740a522b1e849d1306ebde6c4577 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -15,7 +15,9 @@ limitations under the License. #include +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -116,3 +118,43 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, TF_SetStatus(status, TF_OK, ""); } + +int TF_NumInputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_inputs(); +} + +int TF_NumOutputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_outputs(); +} + +void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status); + if (TF_GetCode(status) == TF_OK) { + *tensor = result; + } +} + +void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + ::tensorflow::Tensor cc_tensor; + ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + cc_ctx->set_output(i, cc_tensor); + } +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 2518789a3c141755d0b3373d53642c487331f68b..1a91aa184f11ac8e45b38a1d106c7b445747a7c1 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -85,6 +85,32 @@ TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, // builder is not registered with TensorFlow via TF_RegisterKernelBuilder. TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); +// -------------------------------------------------------------------------- +// OpKernelContext routines + +// TF_NumInputs returns the number of inputs available in ctx. +TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); + +// TF_NumOutputs returns the number of outputs to be placed in *ctx by the +// kernel. +TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); + +// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, + TF_Tensor** tensor, TF_Status* status); + +// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but +// TF_OK, ctx is left unmodified. +// +// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, + const TF_Tensor* tensor, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index e706c7c1d96ee1781d8efc0f28c5e0cbcbc80861..e659ee3c3d258a626ccf03a782ec031b5a703a48 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/kernels.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/op.h" @@ -31,7 +32,6 @@ struct MyCustomKernel { static bool delete_called = false; static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { - LOG(INFO) << "Wow, actually got into creation"; struct MyCustomKernel* s = new struct MyCustomKernel; s->created = true; s->compute_called = false; @@ -51,12 +51,31 @@ static void MyDeleteFunc(void* kernel) { delete s; } +namespace tensorflow { + +static std::unique_ptr GetFakeKernel(const char* device_name, + const char* op_name, + Status* status) { + NodeDef def; + def.set_op(op_name); + def.set_device(device_name); + def.add_input("input1"); + def.add_input("input2"); + return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1, + status); +} + // Tests registration of a single C kernel and checks that calls through the // C/C++ boundary are being made. TEST(TestKernel, TestRegisterKernelBuilder) { const char* kernel_name = "SomeKernelName"; const char* op_name = "FooOp"; - const char* device_name = "barDev"; + const char* device_name = "FakeDeviceName1"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); TF_KernelBuilder* builder = TF_NewKernelBuilder( op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); @@ -65,35 +84,120 @@ TEST(TestKernel, TestRegisterKernelBuilder) { TF_Status* status = TF_NewStatus(); TF_RegisterKernelBuilder(kernel_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); - TF_Buffer* buf = TF_GetRegisteredKernelsForOp("FooOp", status); + TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); - ::tensorflow::KernelList list; + KernelList list; list.ParseFromArray(buf->data, buf->length); ASSERT_EQ(1, list.kernel_size()); - ASSERT_EQ("barDev", list.kernel(0).device_type()); + ASSERT_EQ(device_name, list.kernel(0).device_type()); TF_DeleteBuffer(buf); TF_DeleteStatus(status); } - REGISTER_OP("FooOp") + { + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + } + + ASSERT_TRUE(delete_called); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST(TestKernel, TestInputAndOutputCount) { + const char* kernel_name = "InputOutputCounterKernel"; + const char* op_name = "BarOp"; + const char* device_name = "FakeDeviceName2"; + + REGISTER_OP(op_name) .Input("input1: double") .Input("input2: uint8") .Output("output1: uint8"); + static int num_inputs = 0; + static int num_outputs = 0; + + // A kernel whose Compute function has a side-effect of updating num_inputs + // and num_outputs. Various functions on TF_OpKernelContext are also + // exercised. + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + num_inputs = TF_NumInputs(ctx); + num_outputs = TF_NumOutputs(ctx); + + TF_Tensor* input = nullptr; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s); + EXPECT_EQ(123, *static_cast(TF_TensorData(input))); + TF_GetInput(ctx, -1, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + TF_GetInput(ctx, 3, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + // Copy the input tensor to output. + TF_SetOutput(ctx, 0, input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_SetOutput(ctx, 24, input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + TF_DeleteStatus(s); + if (input != nullptr) { + TF_DeleteTensor(input); + } + }; + + TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr, + my_compute_func, nullptr); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + { - ::tensorflow::NodeDef def; - def.set_op("FooOp"); - def.set_device("bar"); - def.add_input("input1"); - def.add_input("input2"); - ::tensorflow::Status status; - std::unique_ptr<::tensorflow::OpKernel> kernel = - ::tensorflow::CreateOpKernel(::tensorflow::DeviceType("barDev"), - nullptr, nullptr, def, 1, &status); + OpKernelContext::Params p; + DummyDevice dummy_device(nullptr, false); + p.device = &dummy_device; + + Tensor t(tensorflow::uint8(123)); + + gtl::InlinedVector inputs; + // Simulate 2 inputs + inputs.emplace_back(&t); + inputs.emplace_back(); + p.inputs = &inputs; + + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); TF_EXPECT_OK(status); ASSERT_NE(nullptr, kernel.get()); - kernel->Compute(nullptr); - } - ASSERT_TRUE(delete_called); + p.op_kernel = kernel.get(); + OpKernelContext ctx(&p); + kernel->Compute(&ctx); + + ASSERT_EQ(2, num_inputs); + ASSERT_EQ(1, num_outputs); + ASSERT_EQ(123, ctx.mutable_output(0)->scalar()()); + } } + +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index e0ac7130a64d3928c39440c0e10a2d2e1990b9cd..ab1c1be344e2257721507543bc7647d4ff4becb2 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -178,7 +178,7 @@ Status GenArgMethods(const tf2xla::Config& config, TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void set_arg{{NAME}}_data(void* data) { + void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index a2cdab5d1a8e72504ca11b789287d4efd07a59e9..968afad65ed6d4b5510687df484b7ce6743f6a85 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // with dim indices specifying which value. No bounds checking is performed // on dim indices. - void set_arg0_data(void* data) { + void set_arg0_data(const void* data) { set_arg_data(0, data); } float* arg0_data() { @@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg_myfeed_data(void* data) { + void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); } float* arg_myfeed_data() { @@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg1_data(void* data) { + void set_arg1_data(const void* data) { set_arg_data(1, data); } tensorflow::int64* arg1_data() { diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index be91ed4f432b1890c22900f293fd4196e5c9d970..15dcbb2641eca031e82db9aa58dee6a14ab0a2cc 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -76,6 +76,7 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":flags", ":jit_compilation_passes", ":xla_device", @@ -95,6 +96,7 @@ cc_library( srcs = ["xla_gpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", @@ -104,6 +106,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 1fe612d43d10030675cf307b109e4dcc89cb2d79..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -142,11 +142,22 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(ctx->allocate_temp( device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); Notification n; + Status status; ctx->op_device_context()->CopyDeviceTensorToCPU( &device_tensor, "ConstantArgument", reinterpret_cast(ctx->device()), &host_tensor, - [&](Status status) { n.Notify(); }); + [&](Status s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); + if (!status.ok()) { + LOG(ERROR) << "Copying tensor of shape " + << device_tensor.shape().DebugString() << " from " + << ctx->device()->name() << "to CPU failed with " + << status.ToString(); + return status; + } constant_arguments[i] = host_tensor; } } @@ -189,6 +200,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_arguments, variable_args, ctx, &args)); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7df898ad12a15345f45fc96e0ec3d42b6e51731b..e9770647e7ba96cc1db026d12d5f11f52ce98d35 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -63,7 +63,19 @@ Status XlaCpuDeviceFactory::CreateDevices( options.device_ordinal = 0; options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.use_multiple_streams = false; - devices->push_back(absl::make_unique(session_options, options)); + auto device = absl::make_unique(session_options, options); + + // Setting GpuDeviceInfo because eager runtime relies on the device + // context in tensorflow_gpu_device_info(). Also, + // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test. + // We need XlaCpuDevice to be treated not as CPU because it allocates + // XlaTensors, not regular Tensors. + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT); + return status; + } + devices->push_back(std::move(device)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6e6532731e64bd42ee56aa719748988f321e0f17..1f3afe8822d441a5ce37617fe18d7767e9bc72e4 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext( } } +void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("XLA->XLA same-device copies not implemented.")); +} + void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index adf0f994b84d9fbf918a5b2478aa7d106853e038..927f983ba9ef23c8509523f42366c0c89c29db9f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -203,6 +203,8 @@ class XlaAssignVariableOp : public OpKernel { .HostMemory("output") \ .TypeConstraint("T"), \ ArgOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 944f732b99c0924a08932eda0aedd8c815cc51d0..0191315a66f4d331e54fadc9dc6a073a05fd67ef 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,10 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. +#include #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -52,8 +55,35 @@ Status XlaGpuDeviceFactory::CreateDevices( VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); return Status::OK(); } - - for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + string allowed_gpus = + session_options.config.gpu_options().visible_device_list(); + std::set gpu_ids; + int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount(); + if (allowed_gpus.empty()) { + for (int i = 0; i < num_visible_devices; ++i) { + gpu_ids.insert(i); + } + } else { + // For loop below is copied from gpu/gpu_device.cc. It validates + // the visible_device_list and populates gpu_ids set. + const std::vector visible_devices = + absl::StrSplit(allowed_gpus, ','); + for (const string& platform_gpu_id_str : visible_devices) { + int32 platform_gpu_id; + if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + return errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_gpu_id_str, "'. visible_device_list = ", allowed_gpus); + } + if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) { + return errors::InvalidArgument( + "'visible_device_list' listed an invalid GPU id '", platform_gpu_id, + "' but visible device count is ", num_visible_devices); + } + gpu_ids.insert(platform_gpu_id); + } + } + for (int i : gpu_ids) { XlaDevice::Options options; options.platform = platform.ValueOrDie(); options.device_name_prefix = name_prefix; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 437db019a0eabe66417725148d8b121842e90479..554227f09de0ab4d9e07f199b957657f3121ff06 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -199,19 +199,17 @@ class XlaTensorBuffer : public TensorBuffer { public: XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, Allocator* allocator) - : expected_size_(expected_size), + : TensorBuffer(const_cast(ptr)), + expected_size_(expected_size), actual_size_(actual_size), - allocator_(allocator) { - data_ = const_cast(ptr); - } + allocator_(allocator) {} ~XlaTensorBuffer() override { - if (data_) { - allocator_->DeallocateRaw(data_); + if (data()) { + allocator_->DeallocateRaw(data()); } } - void* data() const override { return data_; } size_t size() const override { return expected_size_; } TensorBuffer* root_buffer() override { return this; } @@ -231,7 +229,6 @@ class XlaTensorBuffer : public TensorBuffer { } private: - void* data_; size_t expected_size_; size_t actual_size_; Allocator* allocator_; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index bc3d60b90e58b4018f1c52b09941dedba7ef348a..093b61629cd0b04d5d8488139b8d7262b739f86d 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -408,13 +408,6 @@ tf_xla_py_test( name = "eager_test", size = "large", srcs = ["eager_test.py"], - disabled_backends = [ - # TODO(b/78199195) Support XLA CPU devices in eager runtime - "cpu", - "cpu_ondemand", - # TODO(b/78468222) Enable GPU backend - "gpu", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 25a84fb1b6609106213231db1ca1ce54da8bd960..5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -445,14 +445,9 @@ cc_library( ], deps = [ "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", + "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 1de85004a51bea464f8f0166511402e5dd85ac14..64fdbbebc65bff4ed0b965fcdd534cc9696472b6 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,86 +18,26 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace dump_graph { -namespace { - -struct NameCounts { - mutex counts_mutex; - std::unordered_map counts; -}; - -string MakeUniqueFilename(string name) { - static NameCounts& instance = *new NameCounts; - - // Remove illegal characters from `name`. - for (int i = 0; i < name.size(); ++i) { - char ch = name[i]; - if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { - name[i] = '_'; - } - } - - int count; - { - mutex_lock lock(instance.counts_mutex); - count = instance.counts[name]++; - } - - string filename = name; - if (count > 0) { - absl::StrAppend(&filename, "_", count); - } - absl::StrAppend(&filename, ".pbtxt"); - return filename; -} - -string WriteTextProtoToUniqueFile( - Env* env, const string& name, const char* proto_type, - const ::tensorflow::protobuf::Message& proto) { - const string& dirname = GetDumpGraphFlags()->tf_dump_graph_prefix; - Status status = env->RecursivelyCreateDir(dirname); - if (!status.ok()) { - LOG(WARNING) << "Failed to create " << dirname << " for dumping " - << proto_type << ": " << status; - return "(unavailable)"; - } - string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); - status = WriteTextProto(Env::Default(), filepath, proto); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath - << " : " << status; - return "(unavailable)"; - } - LOG(INFO) << "Dumped " << proto_type << " to " << filepath; - return filepath; -} - -} // anonymous namespace - string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", - graph_def); + return tensorflow::DumpGraphDefToFile( + name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def) { - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - if (flib_def) { - *graph_def.mutable_library() = flib_def->ToProto(); - } - return DumpGraphDefToFile(name, graph_def); + return tensorflow::DumpGraphToFile(name, graph, flib_def, + GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); + return tensorflow::DumpFunctionDefToFile( + name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index dcd38764be94b20dd7917a039d49888d8d025d92..8bc329229648c5aced8d06c99b170803bb3a90f8 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -121,13 +121,11 @@ tf_kernel_library( ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", - "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", @@ -148,6 +146,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 4cfe946b2e6146f034867c06e996ffae42b90705..1b254e328a8c71bd81a0ec700e2af1d81a5fa67a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" namespace tensorflow { namespace { @@ -28,9 +30,11 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->Input(0), ctx->Input(1), - /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, - /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); + auto result = + xla::BatchDot(MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(0), adj_x_), adj_x_), + MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index f4def11d08c31513aec5aad15187016a7294c2fd..90c0ebefb24ec2c4378782e9b15d3f57c33032a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" namespace tensorflow { namespace { @@ -29,7 +29,7 @@ class MatrixTriangularSolveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = TriangularSolve( + auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index a259da6383d461fd11b0d79096bf66aae7ddef06..06c6cc37ec90192486ba15010bfeb763a9ffb987 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -152,7 +152,12 @@ class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, - /*reduction_type=*/ctx->input_type(0)) {} + /*reduction_type=*/ctx->input_type(0)) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -180,10 +185,6 @@ class MaxPool2DOp : public MaxPoolOp { public: explicit MaxPool2DOp(OpKernelConstruction* ctx) : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); @@ -204,7 +205,12 @@ class AvgPoolOp : public PoolingOp { AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ - XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + XlaHelpers::SumAccumulationType(ctx->input_type(0))) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -241,10 +247,6 @@ class AvgPool2DOp : public AvgPoolOp { public: explicit AvgPool2DOp(OpKernelConstruction* ctx) : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); @@ -390,6 +392,11 @@ class AvgPoolGradOp : public XlaOpKernel { OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); + + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -449,10 +456,6 @@ class AvgPool2DGradOp : public AvgPoolGradOp { public: explicit AvgPool2DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP( diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3575e480c1ad4725efb97137313639aa21038bde..3e7a761120317ff85947559b7b2e52be9232afb7 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -17,20 +17,6 @@ filegroup( load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -cc_library( - name = "batch_dot", - srcs = ["batch_dot.cc"], - hdrs = ["batch_dot.h"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", - ], -) - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -52,8 +38,6 @@ cc_library( srcs = ["cholesky.cc"], hdrs = ["cholesky.h"], deps = [ - ":batch_dot", - ":triangular_solve", ":util", ":while_loop", "//tensorflow/compiler/xla:literal", @@ -63,6 +47,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:lib", ], ) @@ -87,7 +74,6 @@ cc_library( srcs = ["qr.cc"], hdrs = ["qr.h"], deps = [ - ":batch_dot", ":util", ":while_loop", "//tensorflow/compiler/xla:literal_util", @@ -100,6 +86,7 @@ cc_library( "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) @@ -124,51 +111,6 @@ cc_library( ], ) -cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/core:lib", - ], -) - -xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = ["noasan"], # sometimes times out, http://b/78650012 - deps = [ - ":triangular_solve", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "util", srcs = ["util.cc"], @@ -187,29 +129,6 @@ cc_library( ], ) -xla_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "while_loop", srcs = ["while_loop.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc deleted file mode 100644 index 5400e8834cb9807f6dd71abe7789b2672e29e905..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" - -#include -#include - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { - -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", - xla::ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); - } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::LiteralUtil::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - xla::PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - - return xla::DotGeneral(x, y, dot_dnums, &precision_proto); - }); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h deleted file mode 100644 index 6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace tensorflow { - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each -// can be elementwise-complex-conjugated by setting the `conjugate_x` or -// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both -// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. -// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index ab3d0a566839343828d176d9a46672824e425613..550ab5b05693b79e60e49577309328ac6846d3f9 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -101,10 +102,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -122,10 +120,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // The columns in [i, n] are zeroed out in `row`, so we just have to // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], // r.T) - auto dot = BatchDot(body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -185,9 +180,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index aa1c0e474311328ac3811c8ca3e7e16afd4f8f3f..d6007748609fdd161cb89692a167eb7ed12fe00c 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -191,12 +191,8 @@ xla::StatusOr QRBlock( auto v_broadcast = xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = - BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - vva = - BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto vva = BatchDot(v_broadcast, a, precision); + vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -278,12 +274,9 @@ xla::StatusOr ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto yv = BatchDot(TransposeInMinorDims(y), v, precision); // wyv has shape [..., m, 1] - auto wyv = - BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto wyv = BatchDot(w, yv, precision); auto z = xla::Mul( -beta, v + wyv, @@ -375,23 +368,15 @@ xla::StatusOr QRDecomposition( // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = - BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a_update = - BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision); + a_update = BatchDot(y, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = - BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - q_update = BatchDot(q_update, y, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto q_update = BatchDot(q_panel, w, precision); + q_update = BatchDot(q_update, TransposeInMinorDims(y), precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 804671fbc75b0a5a6e04b204822b6f084013cd8b..c0bd172d17c192435ba8ee196f9def0491c0bf5c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -113,36 +113,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); - }); -} std::vector ConcatVectors(absl::Span xs, absl::Span ys) { @@ -152,100 +122,4 @@ std::vector ConcatVectors(absl::Span xs, return output; } -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); - auto padded_sizes = ConcatVectors(major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); - }); -} - -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. - std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); - }); -} - -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(x, update, padded_start); - }); -} - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return xla::DynamicUpdateSlice(x, update, padded_starts); -} - -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); - }); -} - -xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); - }); -} - -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; - }); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 80e9e5b002d49581209e608b98606e02709c5876..aec8061cb4322b8d315b6cdc80c7fff1e0cb4cb1 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -38,44 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last values being -// those in `starts`. -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts); - -// Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end); - // Returns the concatenation of `xs` and `ys`. std::vector ConcatVectors(absl::Span xs, absl::Span ys); -// Performs a dynamic slice in the minor dimensions of a Tensor. -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes); - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0], ..., start[n]] = update -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Applies a complex conjugation operation if `a` is complex and `conjugate_a` -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index a1d359e97c4fad3ca74d44a358cba0e8190cdc22..c7341cf8b9e8d7a06fd304ae8766420d20f0c16e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -206,8 +206,14 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. - void set_arg_data(size_t index, void* data) { - buffer_table_[arg_index_table_[index]] = data; + void set_arg_data(size_t index, const void* data) { + // The const_cast is safe because the generated code does not write to arg + // buffers. + // + // buffer_table_ contains pointers to buffers that _will_ be written to by + // generated code so it would be misleading to make buffer_table_ a `const + // void**`. + buffer_table_[arg_index_table_[index]] = const_cast(data); } // ------------------------------ diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index dc188b67a69528693e246f073231448e60ee975d..41db8de29ff0085a30847ff41db4ffbfc774e2a1 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -110,7 +110,11 @@ cc_library( deps = [ ":arithmetic", ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", @@ -123,6 +127,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ ":matrix", + ":slicing", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -172,6 +177,38 @@ cc_library( ], ) +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "sorting", srcs = ["sorting.cc"], @@ -221,3 +258,48 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = ["noasan"], # sometimes times out, http://b/78650012 + deps = [ + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 08a887a6e4660cb2528f0ec7244b7ccc540808d2..36fdda39b4124b9100c6054160f9c17bdf787d6f 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -268,17 +268,16 @@ XlaOp Digamma(XlaOp input) { // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); - auto round_val = xla::Floor(x); + auto round_val = Floor(x); auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); } // Trigonometric functions. @@ -320,4 +319,13 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } +XlaOp MaybeConjugate(XlaOp x, bool conjugate) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == C64 && conjugate; + return perform_conj ? Conj(x) : x; + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 3f06d04b9ae98b3aa75e68cd07810b2b4c24d280..17612bf9fdc0f1eabb338671c93c025c5b268872 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -86,6 +86,10 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); +// Applies a complex conjugation operation if `a` is complex and `conjugate` +// is true, otherwise returns its argument. +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index f78d8509d857bb7723e207dfc99586ed5d3cd770..ffd744d190885b8e3f4149a48a706498b3787618 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -21,6 +21,11 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -71,7 +76,7 @@ XlaOp Triangle(XlaOp x, bool lower) { AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); - xla::XlaOp indicator; + XlaOp indicator; if (lower) { indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); } else { @@ -87,4 +92,94 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { + return InvalidArgument( + "Arguments to BatchDot have different ranks: %s vs. %s", + ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + } + const int ndims = ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return InvalidArgument( + "Arguments to BatchDot must have rank >= 2: got %d", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return InvalidArgument( + "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", + i, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = ndims - 1; + int y_inner_dim = ndims - 2; + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return InvalidArgument( + "Dimensions %d and %d of arguments to BatchDot must be equal: " + "shapes %s vs %s", + x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + + // Check for zero lhs/rhs dim size. + if (ShapeUtil::IsZeroElementArray(x_shape) || + ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = ndims - 2; + int y_outer_dim = ndims - 1; + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return Broadcast( + ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), + dimensions); + } + + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + } + + return DotGeneral(x, y, dot_dnums, &precision_proto); + }); +} + +XlaOp TransposeInMinorDims(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return Transpose(x, permutation); + }); +} + +XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { + return transpose ? TransposeInMinorDims(x) : x; +} } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 7af3e68f20b9ece2bd34b9ba77dbc3499c49cd2a..8856f99c7a0fee8f315aac11fab392cf5536f57b 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -40,6 +40,34 @@ XlaOp UpperTriangle(XlaOp x); // Get the lower triangle part of the last two dimensions XlaOp LowerTriangle(XlaOp x); +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index 8eacc583d1fb3fe4a67ad3d32c0109d7ad6e81f5..0593a7517ac125ca8dc5395cee76f6bc23232cd3 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -25,13 +26,13 @@ limitations under the License. namespace xla { namespace { -class NumericTest : public ClientLibraryTestBase { +class MatrixTest : public ClientLibraryTestBase { protected: template void TestMatrixDiagonal(); }; -XLA_TEST_F(NumericTest, Triangle) { +XLA_TEST_F(MatrixTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); input.FillIota(0); @@ -46,7 +47,7 @@ XLA_TEST_F(NumericTest, Triangle) { } template -void NumericTest::TestMatrixDiagonal() { +void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("GetMatrixDiagonal"); Array3D input(2, 3, 4); input.FillIota(0); @@ -59,11 +60,46 @@ void NumericTest::TestMatrixDiagonal() { ComputeAndCompareR2(&builder, expected, {a_data.get()}); } -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + +Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(MatrixTest, RowBatchDot) { + XlaBuilder builder(TestName()); -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + int n = 4; -XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, TransposeInMinorDims(row)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8c7df3ff5189c817202eaf39adb572f7e232ec2 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/slicing.h" + +namespace xla { + +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return Slice(x, padded_start, padded_end, strides); + }); +} + +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return DynamicUpdateSlice(x, update, start_constant); + }); +} + +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); +} + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + +XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + auto zero = Reshape(ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); + } + return ConcatInDim(builder, padded_starts, 0); + }); +} + +} // namespace + +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return DynamicSlice(x, padded_starts, padded_sizes); + }); +} + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return DynamicUpdateSlice(x, update, padded_starts); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h new file mode 100644 index 0000000000000000000000000000000000000000..6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc similarity index 67% rename from tensorflow/compiler/tf2xla/lib/util_test.cc rename to tensorflow/compiler/xla/client/lib/slicing_test.cc index 442fe92c34ca26cb1a854cc90da8dc034bca79bb..8d362119e01006555db0f82d02626175936e1d05 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -13,28 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { -using UtilTest = xla::ClientLibraryTestBase; -using UtilLeftLookingTest = xla::ClientLibraryTestBase; +using SlicingTest = xla::ClientLibraryTestBase; xla::Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; @@ -63,7 +54,7 @@ xla::Array3D BatchedAValsFull() { }}; } -XLA_TEST_F(UtilTest, Simple2dLookup) { +XLA_TEST_F(SlicingTest, Simple2dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, x, y; @@ -77,7 +68,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(UtilTest, Simple3dLookup) { +XLA_TEST_F(SlicingTest, Simple3dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, index; @@ -92,7 +83,7 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { {a_data.get(), index_data.get()}); } -XLA_TEST_F(UtilTest, SimpleSliceUpdate) { +XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, b, x, y; @@ -111,26 +102,5 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } -XLA_TEST_F(UtilTest, RowBatchDot) { - xla::XlaBuilder builder(TestName()); - - int n = 4; - - xla::XlaOp a, row, index; - auto a_data = - CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); - auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, - "row", &builder, &row); - // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). - auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - - auto l_index = DynamicSliceInMinorDims( - a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); - BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); - - ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); -} - } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc similarity index 63% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve.cc index 4bc3796e4fddcc5bbe13d24433cb905b0bca5910..c5a1d34cc66e6f8c1a832f8a8437163b846a5431 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -29,21 +29,20 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/math/math_util.h" -namespace tensorflow { +namespace xla { // Get the diagonal blocks of the coefficient matrix -xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); - int ndims = xla::ShapeUtil::Rank(shape); - int64 n = xla::ShapeUtil::GetDimension(shape, -1); +XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); + int ndims = ShapeUtil::Rank(shape); + int64 n = ShapeUtil::GetDimension(shape, -1); int64 num_blocks = n / block_size; - xla::XlaOp diag_blocks; + XlaOp diag_blocks; // If the coefficient matrix is exactly the block size, we just add a // singleton dimension i.e. [..., n, n] -> [..., 1, n, n] @@ -58,13 +57,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { if (n > block_size) { // Construct the starting indices of the diagonal blocks auto start_indices = - Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), - xla::ConstantR0(builder, block_size)), + Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks), + ConstantR0(builder, block_size)), /*broadcast_sizes=*/{2}), /*permutation=*/{1, 0}); // Gather the diagonal blocks - xla::GatherDimensionNumbers dim_numbers; + GatherDimensionNumbers dim_numbers; dim_numbers.add_offset_dims(ndims - 1); dim_numbers.add_offset_dims(ndims); dim_numbers.add_start_index_map(ndims - 2); @@ -80,7 +79,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Pad with zeros auto last_blocks = SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); - xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); + PaddingConfig config = MakeNoPaddingConfig(ndims); int64 padding = block_size - n % block_size; config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); @@ -89,9 +88,8 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Add a singleton dimension // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(last_blocks)); - auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); + auto shape_dims = AsInt64Slice(blocks_shape.dimensions()); auto last_blocks_dims = std::vector(ndims); std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); @@ -100,7 +98,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Concatenate with the other blocks if necessary if (n > block_size) { diag_blocks = - xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); } else { diag_blocks = last_blocks; } @@ -110,16 +108,16 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { +XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, + bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is // (..., size, size). We resize this to (num_blocks, size, size). - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / tensorflow::MathUtil::IPow(block_size, 2); diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); @@ -131,9 +129,9 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, // zero (which can happen if the last block was padded) otherwise it will // introduce nans which will propagate auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); + TF_ASSIGN_OR_RETURN(Shape diags_shape, builder->GetShape(diags)); auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); + auto ones = Broadcast(one, AsInt64Slice(diags_shape.dimensions())); diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); @@ -159,40 +157,40 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, auto start_index = (lower) ? 0 : block_size - 1; auto output_block = DynamicUpdateSlice( neg_identity, pos_one, - /*start_indices=*/xla::ConstantR1(builder, 2, start_index)); + /*start_indices=*/ConstantR1(builder, 2, start_index)); // Broadcast diag([1, -1, -1, ...]) to every block - xla::XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); // Now we construct a loop that performs matrix-vector multiplications // inverting the blocks one row at a time - std::vector tuple_shapes = { + std::vector tuple_shapes = { // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), + ShapeUtil::MakeShape(S32, {}), // The output has the shape of A, with one row updated each iteration. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), // The input is a loop invariant. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = One(builder, xla::S32); - auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); // Construct the loop condition function. - std::unique_ptr condb = + std::unique_ptr condb = builder->CreateSubBuilder("InvertDiagCond"); { auto i = GetTupleElement( Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, xla::ConstantR0(condb.get(), block_size)); + Lt(i, ConstantR0(condb.get(), block_size)); } TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); // Construct the loop body function. - std::unique_ptr bodyb = + std::unique_ptr bodyb = builder->CreateSubBuilder("InvertDiagBody"); { auto input_tuple = @@ -202,21 +200,21 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, auto body_out = GetTupleElement(input_tuple, 1); auto body_input = GetTupleElement(input_tuple, 2); - auto zero = xla::ConstantR1(bodyb.get(), 1, 0); + auto zero = ConstantR1(bodyb.get(), 1, 0); auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; auto start_indices = - xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); + ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); auto input_row = DynamicSlice(body_input, start_indices, /*slice_sizes=*/{num_blocks, 1, block_size}); // We want -L21 L11^{-1} - xla::DotDimensionNumbers dnums; + DotDimensionNumbers dnums; dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfig precision_proto; + PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); @@ -224,7 +222,7 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, body_out = DynamicUpdateSlice(body_out, update, start_indices); auto next_i = i + ScalarLike(i, 1); - xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); + Tuple(bodyb.get(), {next_i, body_out, body_input}); } TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); @@ -238,27 +236,26 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, /*broadcast_dimensions=*/{0, 1}); // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks( - xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(inv_diag_blocks)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); - - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - int64 ndims = xla::ShapeUtil::Rank(a_shape); - int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); +XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + int64 ndims = ShapeUtil::Rank(a_shape); + int64 n = ShapeUtil::GetDimension(a_shape, -1); int64 num_blocks = n / block_size + (n % block_size != 0); int64 m_dim = (left_side) ? -1 : -2; - int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); + int64 m = ShapeUtil::GetDimension(b_shape, m_dim); // Initialize the solution auto x = ZerosLike(b); @@ -294,7 +291,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( } auto b_row = SliceInMinorDims(b, start, end); - xla::XlaOp remainder; + XlaOp remainder; if (i == 0) { remainder = b_row; } else { @@ -311,29 +308,27 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x, + precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a), + precision); } } - xla::XlaOp x_update; - auto zero = Zero(builder, xla::S32); - auto start_index = - xla::ConstantR0WithType(builder, xla::S32, j * block_size); - std::vector update_starts = {start_index, zero}; + XlaOp x_update; + auto zero = Zero(builder, S32); + auto start_index = ConstantR0WithType(builder, S32, j * block_size); + std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = - BatchDot(inv_block, remainder, transpose_a, false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a), + remainder, precision); } else { - x_update = - BatchDot(remainder, inv_block, false, transpose_a, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(remainder, + MaybeTransposeInMinorDims(inv_block, transpose_a), + precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -343,24 +338,24 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( }); } -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) { + return InvalidArgument( + "Arguments to TriangularSolve have shapes with different ranks: " + "%s vs. %s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } - const int64 ndims = xla::ShapeUtil::Rank(a_shape); + const int64 ndims = ShapeUtil::Rank(a_shape); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); + return InvalidArgument( + "Arguments to TriangularSolve was rank %d but must have rank >= 2.", + ndims); } // The batch dimensions must be equal. std::vector batch_dimensions; @@ -368,32 +363,33 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, int64 a_size = a_shape.dimensions(i); int64 b_size = b_shape.dimensions(i); if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + return InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal; " + "shapes were %s and %s.", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } batch_dimensions.push_back(a_size); } - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + if (ShapeUtil::GetDimension(a_shape, -1) != + ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must be a batched square matrix;" + " shape was: %s", + ShapeUtil::HumanString(a_shape)); } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + const int64 m = ShapeUtil::GetDimension(b_shape, -2); + const int64 n = ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) { + return InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes %s and " + "%s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", + return InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got %d", block_size); } @@ -413,4 +409,4 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h similarity index 88% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.h rename to tensorflow/compiler/xla/client/lib/triangular_solve.h index 2303234f361e54cd2a0ad495cb03b371bed76877..50a3b30ebd1c15eb6d2ace4e351cb41f21db7093 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Solves systems of linear equations with lower or upper triangular coefficient // matrices by forward- or back-substitution. Broadcasting along leading @@ -57,11 +57,11 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::XlaOp TriangularSolve( - xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, +XlaOp TriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc similarity index 99% rename from tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index aeebf16028d40189203cdfd815f06a339ee72902..f6a70d64a788d95a456774ccbbcf67f2e5cac98b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { using TriangularSolveTest = xla::ClientLibraryTestBase; @@ -330,4 +330,4 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { } } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index aaa5d6989eefb94edb8921d13f96e3705aa3e3a4..049cd15738a619294b19d5cf74ca514d7b4a00ad 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( + ShapeUtil::HumanStringWithLayout( computation_layout.parameter_layout(i).shape()), - ShapeUtil::HumanString(arguments[i]->on_host_shape())); + ShapeUtil::HumanStringWithLayout(arguments[i]->on_host_shape())); } } diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index d7e7b9e621894f1c363734d6415a38d2e8165463..20609cad58d920c0c272899c41efeb99d23cd490 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -334,6 +334,12 @@ void AllocateFlags() { "overhead from context switching but we let the user override this " "behavior to help run tests on the host that run models in parallel " "across multiple devices."), + tensorflow::Flag( + "xla_gpu_disable_ptxas_optimizations", + bool_setter_for( + &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), + flag_values->xla_gpu_disable_ptxas_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3)."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index 12b7094705e75305dc43a013576f4549dd5f4185..267701e9c0e42a21d2cda6238520f6a9692e7e76 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -31,3 +31,5 @@ upper_tabs: - title: XLA compile API path: /xla/tutorials/xla_compile status: experimental + +- include: /_upper_tabs_right.yaml diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index c1381779c91fab01e6a5c5e5658fb87c5c0a8ddc..6e2ee866321a070d55a7221c7c68024ceaa93448 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -314,7 +314,50 @@ CompiledLocalComputation::CompiledLocalComputation( std::unique_ptr executable) : executable_(std::move(executable)) {} -StatusOr CompiledLocalComputation::Execute( +StatusOr CompiledLocalComputation::Execute( + absl::Span argument_handles) { + LocalClient* client = GetOrCreateLocalClient(); + StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + StatusOr result_buffer_status; + if (!device_ordinal_status.ok()) { + result_buffer_status = device_ordinal_status.status(); + } else { + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(1, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + + result_buffer_status = executable_->Run(argument_buffers, options); + } + + if (!result_buffer_status.ok()) { + return InternalError( + "Failed running replica 0 (other replicas may have failed as well): " + "%s.", + result_buffer_status.status().ToString()); + } + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); +} + +StatusOr CompiledLocalComputation::ExecutePerReplica( absl::Span> argument_handles) { LocalClient* client = GetOrCreateLocalClient(); const int num_replicas = GetReplicaCount(); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 8247b057754ffe035e752662f175a13440468ec1..149e44570df5c6a3df88bbe2ffa779be47842d82 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -173,7 +173,13 @@ class CompiledLocalComputation { public: CompiledLocalComputation(std::unique_ptr executable); - StatusOr Execute( + StatusOr Execute( + absl::Span argument_handles); + + // Execute on many replicas. Takes a sequence of argument lists (one argument + // list per replica) and returns a tuple of results (one result per replica). + // The number of argument lists must be equal to the replica count. + StatusOr ExecutePerReplica( absl::Span > argument_handles); private: diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 88349b788eae1a0230b0a38ca512785eb6bfe63c..d23d693c1e5bde43b52959e4397aa311268411bb 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1029,6 +1029,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::XrtAllocationTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; %unignore xla::swig::CompiledXrtComputation; %unignore xla::swig::CompiledXrtComputation::Execute; %unignore xla::swig::LocalComputation; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 4dbbeebd0b0b1a98023eec8ef2adf603975656a8..c91a2aaf56dfe2127168628c78e0c4b868a28055 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -588,9 +588,16 @@ class LocalComputation(object): compile_options=compile_options, layout_fn=layout_fn) - def Execute(self, arguments=()): + def GetReturnValueShape(self): + return _wrap_shape(self._c_computation.GetReturnValueShape()) + + def Execute(self, arguments=(), check_for_deleted_args=True): """Execute on one replica with LocalBuffer arguments and return value.""" - return self.ExecutePerReplica([arguments])[0] + if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + raw_args = [arg.c_buffer for arg in arguments] + output_buffer = self._c_computation.Execute(raw_args) + return LocalBuffer(output_buffer, backend=self._backend, replica=0) def ExecutePerReplica(self, arguments=None): """Execute on many replicas with LocalBuffer arguments and return value. @@ -633,7 +640,7 @@ class LocalComputation(object): 'Multi-replica execution is not yet supported via the XRT backend.') output_buffers = [self._c_computation.Execute(stripped_args[0])] else: - output_buffer_tup = self._c_computation.Execute(stripped_args) + output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) size = output_buffer_tup.size() output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 81e71eee5208cf488269045b70989e062cb83473..4c21ae2a427477caa86fb4130616c38eb3bcf006 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1910,6 +1910,41 @@ cc_library( ], ) +cc_library( + name = "dynamic_dimension_inference", + srcs = ["dynamic_dimension_inference.cc"], + hdrs = ["dynamic_dimension_inference.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "dynamic_dimension_inference_test", + srcs = ["dynamic_dimension_inference_test.cc"], + deps = [ + ":dynamic_dimension_inference", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "reshape_mover_test", srcs = ["reshape_mover_test.cc"], @@ -2067,7 +2102,8 @@ tf_cc_test( srcs = ["hlo_computation_test.cc"], deps = [ ":hlo", - ":hlo_matchers", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2661,7 +2697,6 @@ tf_cc_test( ":algebraic_simplifier", ":computation_layout", ":hlo", - ":hlo_matchers", ":layout_assignment", ":pattern_matcher", ":pattern_matcher_gmock", @@ -2675,6 +2710,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index a348bcf0a232994a046df51563a9167faac08190..985c5af1c4d89425dd6693585e42e22510fe21f8 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include #include #include #include @@ -68,6 +69,45 @@ bool IsAll(const HloInstruction* op, int8 value) { } } +// Checks whether `op` is a floating-point constant or broadcast of a constant +// of the form +/- 2^k for some integer k positive, negative, or zero. Such +// values are interesting because multiplying by a power of 2 just moves the +// exponent. +bool IsAllFpConstantPowerOf2(const HloInstruction* op) { + // Unwrap the broadcast if necessary. + const HloInstruction* c; + if (!Match(op, m::ConstantEffectiveScalar(&c)) && + !Match(op, m::Broadcast(m::Constant(&c).WithShape( + m::Shape().IsEffectiveScalar())))) { + return false; + } + auto val = [&]() -> absl::optional { + switch (c->shape().element_type()) { + case BF16: + return static_cast(c->literal().GetFirstElement()); + case F16: + return static_cast(c->literal().GetFirstElement()); + case F32: + return c->literal().GetFirstElement(); + case F64: + return c->literal().GetFirstElement(); + default: + // Cowardly refuse to consider complex types. + return absl::nullopt; + } + }(); + if (!val) { + return false; + } + + int exp; + double mantissa = std::frexp(*val, &exp); + // frexp returns a value in the range (-1; -0.5] U [0.5, 1). A return value + // of +/-0.5 therefore indicates that the floating point value is a power of + // 2. + return mantissa == 0.5 || mantissa == -0.5; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -415,6 +455,40 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { sum_of_constants)); } + // A*C + B*C => (A+B)*C + // + // - If A, B, and C are integers, do this unconditionally. Proof of + // correctness: https://rise4fun.com/Alive/u9X. + // + // - If A, B, and C are floating point, do this if C is a scalar constant or + // broadcast of scalar constant and is equal to +/- 2^k for some (possibly + // negative) integer k. + // + // Multiplying by a power of 2 just moves the exponent, so our answer is + // exact modulo rounding of intermediate results so long as + // + // - none of the three products has an exponent which underflows (so the + // result is 0 or denormal), and + // - none of the three products overflows to inf. + // + // Proof: See algebraic_simplifier_proof_distributive_property.py. + // + // We deem these differences in rounding, underflow, and overflow + // acceptable in the ML context. + HloInstruction *b, *c; + if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || + (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && + (ShapeUtil::ElementIsIntegral(add->shape()) || + IsAllFpConstantPowerOf2(c))) { + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary( + add->shape(), HloOpcode::kMultiply, + computation_->AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kAdd, a, b)), + c)); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py new file mode 100644 index 0000000000000000000000000000000000000000..5da13da041b4ded813876af7ca379025187545ab --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Proof that transforming (A*C)+(B*C) <=> (A+B)*C is "safe" if C=2^k. + +Specifically, for all floating-point values A, B, and C, if + + - C is equal to +/- 2^k for some (possibly negative) integer k, and + - A, B, C, A*C, B*C, and A+B are not subnormal, zero, or inf, + +then there exists a rounding mode rm in [RTZ, RNE] such that + + (A*C) + (B*C) == (A+B) * C (computed with rounding mode rm). + +Informally, this means that the equivalence holds for powers of 2 C, modulo +flushing to zero or inf, and modulo rounding of intermediate results. + +Requires z3 python bindings; try `pip install z3-solver`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import z3 + +# We do float16 because it lets the solver run much faster. These results +# should generalize to fp32 and fp64, and you can verify this by changing the +# value of FLOAT_TY (and then waiting a while). +FLOAT_TY = z3.Float16 + +a = z3.FP("a", FLOAT_TY()) +b = z3.FP("b", FLOAT_TY()) +c = z3.FP("c", FLOAT_TY()) + +s = z3.Solver() + +# C must be a power of 2, i.e. significand bits must all be 0. +s.add(z3.Extract(FLOAT_TY().sbits() - 1, 0, z3.fpToIEEEBV(c)) == 0) + +for rm in [z3.RTZ(), z3.RNE()]: + z3.set_default_rounding_mode(rm) + before = a * c + b * c + after = (a + b) * c + + # Check that before == after, allowing that 0 == -0. + s.add( + z3.Not( + z3.Or( + before == after, # + z3.And(z3.fpIsZero(before), z3.fpIsZero(after))))) + + for x in [ + (a * c), + (b * c), + (a + b), + ]: + s.add(z3.Not(z3.fpIsSubnormal(x))) + s.add(z3.Not(z3.fpIsZero(x))) + s.add(z3.Not(z3.fpIsInf(x))) + +if s.check() == z3.sat: + m = s.model() + print("Counterexample found!") + print(m) + print("a*c: ", z3.simplify(m[a] * m[c])) + print("b*c: ", z3.simplify(m[b] * m[c])) + print("a+b: ", z3.simplify(m[a] + m[b])) + print("a*c + b*c: ", z3.simplify(m[a] * m[c] + m[b] * m[c])) + print("(a+b) * c: ", z3.simplify((m[a] + m[b]) * m[c])) +else: + print("Proved!") diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 628d805f1c390d2c1b6571b35293adc36a95441a..14ce519b6a0fd221070006d336d23bddeb6cd621 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -80,6 +80,128 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + p2 = s32[8] parameter(2) + x = s32[8] multiply(p0, p2) + y = s32[8] multiply(p1, p2) + ROOT sum = s32[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), m::Parameter(2)))); +} + +// A*C + B*C => (A+B)*C if C is a floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.125) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::ConstantScalar(0.125)))); +} + +// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + x = f32[4] multiply(p0, b) + y = f32[4] multiply(p1, b) + ROOT sum = f32[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + +// A*C + B*C => (A+B)*C simplification should not happen if C is not a +// floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.3) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are +// complex numbers. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = c64[8] parameter(0) + p1 = c64[8] parameter(1) + p2 = c64[8] parameter(2) + x = c64[8] multiply(p0, p2) + y = c64[8] multiply(p1, p2) + ROOT sum = c64[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are complex. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = bf16[4] parameter(0) + p1 = bf16[4] parameter(1) + c = bf16[] constant(0.125) + b = bf16[4] broadcast(c), dimensions={} + x = bf16[4] multiply(p0, b) + y = bf16[4] multiply(p1, b) + ROOT sum = bf16[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index c11452a6fbd49a1fc382d11d24a7d7b7eeab0bcc..362bc44a1cf377b51c5519c6ab5e0d9628e80e58 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -60,7 +60,7 @@ absl::optional MatchesArCrsPattern( absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { - CHECK(HloOpcode::kParameter == instruction->opcode()); + CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); HloComputation* computation = instruction->parent(); auto caller_instructions = call_graph_->GetComputationCallers(computation); if (caller_instructions.size() == 1) { @@ -120,7 +120,7 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( return false; } for (auto tuple : tuples) { - CHECK(tuple->opcode() == HloOpcode::kTuple); + CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), tuple->mutable_operand(i2), visited_pairs)) { @@ -160,13 +160,6 @@ bool ArCrsCombiner::InstructionsComputeSameValue( if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { return false; } - if (opcode1 == HloOpcode::kConstant || i1->IsCrossModuleAllReduce()) { - return i1->Identical( - *i2, - /*eq_operands=*/std::equal_to(), - /*eq_computations=*/std::equal_to(), - /*layout_sensitive=*/false); - } visited_pairs->emplace(min_uid, max_uid); for (int i = 0; i < operands1.size(); ++i) { auto operand1 = operands1[i]; @@ -175,14 +168,28 @@ bool ArCrsCombiner::InstructionsComputeSameValue( return false; } } + if (opcode1 == HloOpcode::kParameter) { + // In the general case, we don't try to prove equality of parameters. + // We only try in the context of get-tuple-element + // (see TupleElementsComputeSameValue). + return false; + } if (opcode1 == HloOpcode::kGetTupleElement) { - if (i1->tuple_index() == i2->tuple_index()) { - return true; - } - return TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), + return i1->tuple_index() == i2->tuple_index() || + TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), i2->tuple_index(), visited_pairs); } - return true; + // Don't check that the operands are identical, because Identical can + // return false for instructions that compute the same value but are not + // identical, which we don't want. We have checked the arguments with + // InstructionsComputeSameValue earlier. + auto eq_instructions = [](const HloInstruction* i1, + const HloInstruction* i2) -> bool { return true; }; + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + return i1->Identical(*i2, eq_instructions, eq_computations, + /*layout_sensitive=*/false); } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { @@ -203,12 +210,12 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { auto instr_0 = instruction_vec[0]; auto add_0 = instr_0->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_0->opcode()); + CHECK_EQ(HloOpcode::kAdd, add_0->opcode()); for (int i = 1; i < instruction_vec.size(); ++i) { auto instr_i = instruction_vec[i]; auto add_i = instr_i->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_i->opcode()); + CHECK_EQ(HloOpcode::kAdd, add_i->opcode()); absl::flat_hash_map visited_pairs; if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { all_reduce_map_.erase(it.first); @@ -242,30 +249,22 @@ StatusOr ArCrsCombiner::RewriteGraph() { HloInstruction* other_summand = (add->operands()[0] == convert) ? add->operands()[1] : add->operands()[0]; - // Remove the AllReduce and replace the CRS with: - // AllReduce - (other_summand * (num_spatial_partitions_ - 1)) + // To move the AR past the addition, we need to divide other_summand by + // the number of spatial partitions. + CHECK_EQ(all_reduce->user_count(), 1); TF_CHECK_OK( all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); - crs->set_all_reduce_id(all_reduce->all_reduce_id()); - auto new_shape = crs->shape(); - HloInstruction* to_subtract; - if (num_spatial_partitions_ == 2) { - to_subtract = other_summand; - } else { - Literal partitions_minus_1_lit = Literal(new_shape); - partitions_minus_1_lit.PopulateWithValue( - num_spatial_partitions_ - 1); - auto partitions_minus_1_const = parent_computation->AddInstruction( - HloInstruction::CreateConstant(partitions_minus_1_lit.Clone())); - to_subtract = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kMultiply, other_summand, - partitions_minus_1_const)); - } - auto sub = + auto shape = other_summand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kSubtract, crs, to_subtract)); - TF_CHECK_OK(crs->ReplaceAllUsesWith(sub)); + shape, HloOpcode::kDivide, other_summand, divisor)); + TF_CHECK_OK(other_summand->ReplaceUseWith(add, division)); + // The AllReduce and the CRS are combined to an all-core AllReduce. + crs->set_all_reduce_id(all_reduce->all_reduce_id()); TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); } } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 9d5eaf63ccf32cd78b8c11f12f9bccdfd1fec3e0..10171835d83c75fef091a34b8fe102d263211307 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -48,6 +48,43 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + ROOT %tuple = (f32[], f32[]) tuple(%x, %x) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %tuple = (f32[], f32[]) tuple(%x, %y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) { const char* module_str = R"( HloModule foobar @@ -69,6 +106,46 @@ ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { const char* module_str = R"( HloModule foobar @@ -320,11 +397,15 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { ArCrsCombiner combiner(2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Subtract(op::CrossReplicaSum(), op::Constant()), - op::Subtract(op::CrossReplicaSum(), op::Constant()))); - auto sub = module->entry_computation()->root_instruction()->operands()[0]; - auto crs_after = sub->operands()[0]; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())), + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); for (int i = 0; i < replica_groups_before.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 07d6680a7232e4541554cbec5c84f5b9bcfd54e2..95c7724c3c93507ae61a984301ecfc0111bef192 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -205,38 +205,6 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // If the code generator handles depthwise separable convolutions // inherently, then no filter expansion is needed. if (!filter_expansion_ && depthwise_separable) { - const int64 old_kernel_input_feature_dimension = - dim_numbers.kernel_input_feature_dimension(); - const int64 old_kernel_output_feature_dimension = - dim_numbers.kernel_output_feature_dimension(); - - // For depthwise convolutions, we want the kernel input feature dimension - // to be smaller than the output feature dimension. If that's not the - // case, we swap the dimensions. - if (old_kernel_input_feature_dimension > - old_kernel_output_feature_dimension) { - Shape reshaped_filter_shape = filter->shape(); - auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); - std::swap(dimensions[old_kernel_input_feature_dimension], - dimensions[old_kernel_output_feature_dimension]); - - auto reshaped_filter = - add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); - - dim_numbers.set_kernel_input_feature_dimension( - old_kernel_output_feature_dimension); - - dim_numbers.set_kernel_output_feature_dimension( - old_kernel_input_feature_dimension); - - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), - reshaped_filter, group_count, convolution->window(), dim_numbers, - convolution->precision_config()); - - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); - } return Status::OK(); } // We want to repeat 'filter' in the 'input_feature_dim' dimension diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index efccadedf27181a4cddf4f1dc3610f7c6db1d821..bd6868d397fbcc34fd720a09ee1f08353dd4dfab 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -296,6 +296,9 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); + registry->Register("__sincosf_stret", + reinterpret_cast(__sincosf_stret)); + registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret)); #else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -311,6 +314,12 @@ bool RegisterKnownJITSymbols() { registry->Register("memcpy", reinterpret_cast(memcpy)); registry->Register("memmove", reinterpret_cast(memmove)); registry->Register("memset", reinterpret_cast(memset)); + +#ifdef __APPLE__ + registry->Register("memset_pattern16", + reinterpret_cast(memset_pattern16)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d0472689bf48092ceef2e9792c1358687d707ec --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -0,0 +1,459 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +namespace { +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} +} // namespace + +class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { + public: + explicit DynamicDimensionInferenceVisitor( + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) + : param_bindings_(param_bindings), parent_(parent) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static Status Run(HloComputation* computation, + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) { + DynamicDimensionInferenceVisitor visitor(param_bindings, parent); + return computation->Accept(&visitor); + } + + Status HandleParameter(HloInstruction* hlo) override; + + Status HandleReduce(HloInstruction* hlo) override; + + Status HandleDot(HloInstruction* hlo) override; + + Status HandleTranspose(HloInstruction* hlo) override; + + Status HandleReshape(HloInstruction* hlo) override; + + Status HandlePad(HloInstruction* hlo) override; + + Status HandleBroadcast(HloInstruction* hlo) override; + + Status HandleGetDimensionSize(HloInstruction* hlo) override; + + Status HandleSelect(HloInstruction* hlo) override; + + Status HandleConvolution(HloInstruction* hlo) override; + + Status HandleReduceWindow(HloInstruction* hlo) override; + + Status HandleSelectAndScatter(HloInstruction* hlo) override; + + Status HandleGetTupleElement(HloInstruction* hlo) override; + + Status HandleElementwiseUnary(HloInstruction* hlo) override; + + Status HandleElementwiseBinary(HloInstruction* hlo) override; + + private: + using OperandDynamicDimensionFn = std::function; + + Status ForEachOperandDynamicDimension(HloInstruction* inst, + const OperandDynamicDimensionFn&); + + // Pass through a dynamic dimension from the input to the output with the same + // value and index in the shape. This is a helper function to handle trivial + // instructions like elementwise operations. + Status PassThroughDynamicDimension(HloInstruction*); + + // The dynamic parameter bindings of this computation. + const DynamicParameterBinding& param_bindings_; + + // A pointer to DynamicDimensionInference, used to update the dynamic mapping. + DynamicDimensionInference* parent_; +}; + +Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + return UnimplementedStrCat( + "Asked to propagate a dynamic dimension from hlo ", + operand->ToString(), "@", index.ToString(), "@", dimension, + " to hlo ", hlo->ToString(), ", which is not implemented."); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetTupleElement( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (hlo->tuple_index() == index[0]) { + ShapeIndex new_index = + ShapeIndexView(index).ConsumeFront().ToShapeIndex(); + parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size); + } + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + int64 broadcast_dim = hlo->dimensions(dimension); + parent_->SetDynamicSize(hlo, index, broadcast_dim, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (operand_index != 0) { + return Unimplemented( + "Dynamic dimension on padding value is not supported"); + } + const PaddingConfig_PaddingConfigDimension& padding_config = + hlo->padding_config().dimensions(dimension); + if (padding_config.interior_padding() == 0 && + padding_config.edge_padding_low() == 0 && + padding_config.edge_padding_high() == 0) { + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + return Status::OK(); + } else { + return Unimplemented( + "Dynamic dimension propagation on padding dimension is not " + "supported."); + } + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce = hlo; + int64 operand_count = reduce->operand_count(); + CHECK_EQ(operand_count % 2, 0); + if (operand_index >= operand_count / 2) { + // Init values doesn't have dynamic size. + return Status::OK(); + } + if ((absl::c_count(reduce->dimensions(), dimension) != 0)) { + // Dimension is to be reduce, stop tracing. + return Status::OK(); + } + + // Find out the new dynamic dimension after reduce. + int64 dimensions_not_reduced_count = 0; + for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (dimension == i) { + parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, + dynamic_size); + + return Status::OK(); + } + if (absl::c_count(reduce->dimensions(), i) == 0) { + dimensions_not_reduced_count++; + } + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* dot = hlo; + const DotDimensionNumbers& dimension_numbers = + dot->dot_dimension_numbers(); + // A map from the operand dimensions to result dimension. + absl::flat_hash_map result_dim_mapping; + int64 current_result_dims = 0; + std::unordered_set batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); + + for (int64 i : dimension_numbers.rhs_batch_dimensions()) { + result_dim_mapping[i] = current_result_dims++; + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.lhs_contracting_dimensions(), i)) { + if (operand_index == 0) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.rhs_contracting_dimensions(), i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), + i)) { + if (operand_index == 1) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + // Check if the operand dim is in the result shape. If so, add another + // work item to trace that dimension. + auto iter = result_dim_mapping.find(dimension); + if (iter != result_dim_mapping.end()) { + parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension], + dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleConvolution( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* conv = hlo; + const ConvolutionDimensionNumbers& dimension_numbers = + conv->convolution_dimension_numbers(); + + if (operand_index == 0) { + if (dimension == dimension_numbers.input_batch_dimension()) { + parent_->SetDynamicSize(conv, {}, + dimension_numbers.output_batch_dimension(), + dynamic_size); + return Status::OK(); + } + + if (dimension == dimension_numbers.input_feature_dimension()) { + return Status::OK(); + } + } else { + if (dimension == dimension_numbers.kernel_input_feature_dimension()) { + return Status::OK(); + } + } + + return Unimplemented("Dynamic Spatial Convolution is not supported: %s", + conv->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize( + HloInstruction*) { + // Dynamic dimension doesn't propagate through GetDimensionSize: + // + // Input: F32[x, y, z] + // | + // GetDimensionSize(1): U32[] + // + // The returned value is a scalar, which doesn't have any dynamic dimension in + // the shape (although the value contains the real size of the dynamic + // dimension of the input). + return Status::OK(); +} + +Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reshape = hlo; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand->shape(), + reshape->shape()); + for (auto& unmodified : unmodified_dims) { + if (unmodified.first == dimension) { + parent_->SetDynamicSize(reshape, {}, unmodified.second, + dynamic_size); + return Status::OK(); + } + } + return Unimplemented( + "Dynamic Reshape on modified dimensions is yet not supported: %s", + reshape->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduceWindow( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce_window = hlo; + const WindowDimension& window_dimension = + reduce_window->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial reduce window is not supported: %s", + reduce_window->ToString()); + } + + parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* select_and_scatter = hlo; + const WindowDimension& window_dimension = + select_and_scatter->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial select and scatter is not supported: %s", + select_and_scatter->ToString()); + } + + parent_->SetDynamicSize(select_and_scatter, {}, dimension, + dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { + return param_bindings_.ForEachBinding( + [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter, + const DynamicParameterBinding::DynamicDimension& dynamic_dimension) { + if (dynamic_dimension.parameter_num != hlo->parameter_number()) { + return Status::OK(); + } + HloComputation* computation = hlo->parent(); + HloInstruction* target_parameter = + computation->parameter_instruction(dynamic_dimension.parameter_num); + + HloInstruction* dynamic_size = + computation->parameter_instruction(dynamic_parameter.parameter_num); + for (int64 i : dynamic_parameter.parameter_index) { + dynamic_size = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(dynamic_size->shape(), {i}), + dynamic_size, i)); + } + + parent_->SetDynamicSize(target_parameter, + dynamic_dimension.parameter_index, + dynamic_dimension.dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( + HloInstruction* inst, const OperandDynamicDimensionFn& fn) { + for (int64 operand_index = 0; operand_index < inst->operand_count(); + ++operand_index) { + auto iter = + parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index)); + if (iter != parent_->per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = parent_->GetDynamicSize( + dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim, operand_index, + dynamic_size)); + } + } + } + return Status::OK(); +} + +/* static */ +StatusOr DynamicDimensionInference::Run( + HloModule* module) { + VLOG(0) << "Param Config " << module->dynamic_parameter_binding().ToString(); + DynamicDimensionInference inference(module); + TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); + return inference; +} + +DynamicDimensionInference::DynamicDimensionInference(HloModule* module) + : module_(module) {} + +Status DynamicDimensionInference::AnalyzeDynamicDimensions() { + return DynamicDimensionInferenceVisitor::Run( + module_->entry_computation(), module_->dynamic_parameter_binding(), this); +} + +HloInstruction* DynamicDimensionInference::GetDynamicSize( + HloInstruction* inst, const ShapeIndex& index, int64 dim) const { + auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim}); + if (iter != dynamic_mapping_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..164d15bf111a92e3da957f609b54ee0662ef18b1 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// DynamicDimensionInference analyzes each HLO instruction in a graph and +// inferences which dimensions are dynamic and which scalar instructions +// represent the runtime real size of those dynamic dimensions. +class DynamicDimensionInference { + public: + static StatusOr Run(HloModule* module); + + string ToString() const; + + // If the dimension `dim` of instruction `inst` at `index` has a dynamic size, + // returns a scalar HloInstruction that represents the runtime size of that + // dimension. Otherwise returns nullptr. + HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, + int64 dim) const; + + friend class DynamicDimensionInferenceVisitor; + + private: + explicit DynamicDimensionInference(HloModule* module); + + // DynamicDimension is used as a key in the dynamic key-value mapping. It + // unambiguously represents a dynamic dimension of a instruction at a given + // index. + struct DynamicDimension { + // HloInstruction that holds the dimension. + HloInstruction* inst; + // Subshape of the instruction that holds the dimension. + ShapeIndex index; + // The dimension number of the dynamic dimension at given index of a given + // instruction. + int64 dim; + + // Artifacts needed to make this struct able to be used as a `key` in absl + // maps. "friend" keywords are added so these functions can be found through + // ADL. + template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.inst, m.index, m.dim); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.inst == rhs.inst && lhs.index == rhs.index && + lhs.dim == rhs.dim; + } + }; + + // Update the dynamic mapping so that we know dimension `dim` of instruction + // `inst` at `index` has a dynamic size, and its runtime size is represented + // by a scalar instruction `size`. + void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, + HloInstruction* size) { + dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size); + auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); + iter.first->second.emplace(DynamicDimension{inst, index, dim}); + } + + // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in + // module_. + Status AnalyzeDynamicDimensions(); + + // HloModule being analyzed. + HloModule* module_; + + // dynamic_mapping_ holds the result of the analysis. It maps a dynamic + // dimension to a scalar HloInstruction that represents the real dynamic size + // of the dynamic dimension. + using DynamicMapping = absl::flat_hash_map; + DynamicMapping dynamic_mapping_; + + using PerHloDynamicDimensions = + absl::flat_hash_map>; + PerHloDynamicDimensions per_hlo_dynamic_dimensions_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea9ebed45d99797ce4f80376ec3d0b758da3ca17 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -0,0 +1,535 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicDimensionInferenceTest : public HloTestBase { + protected: + DynamicDimensionInferenceTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); + } + + Status RunInference() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module_.get())); + + inference_ = absl::make_unique(inference); + return Status::OK(); + } + + HloComputation* GetAdd() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + std::unique_ptr inference_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); +}; + +TEST_F(DynamicDimensionInferenceTest, ParamTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "param")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), param2); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(param2, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, GetTupleElement) { + // When data flows through GTE, the dynamic dimension size keeps the + // same, and the index has its front popped. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + auto gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, param, 0)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_THAT(inference_->GetDynamicSize(gte, {}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) { + // When data flows through elementwise, the dynamic dimension size keeps the + // same. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto* negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestI) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestII) { + // Same as ReduceTestI, but only reduce one dimension. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction( + HloInstruction::CreateReduce(reduce_shape, negate, init, {1}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, DotTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto xz_shape = ShapeUtil::MakeShape(F32, {xdim, zdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, TransposeTest) { + // Test the ability to trace unmodified dimensions + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/3, scalar_shape_, "size_param")); + + auto* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(output_shape, a_param, {2, 1, 0})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 3})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 3), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 4), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + Status status = RunInference(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); +} + +TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { + // Test the ability to trace broadcast dimension. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 4}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(output_shape, a_param, {1})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { + // Test the ability to trace reduce window batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. + for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { + // Test the ability to trace select and scatter batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. + for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index d7829045cc127deaa4c2c9b705dca5285d704af2..3a09d4d4716950a09d65dd093272482d55ac5c27 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -43,13 +43,14 @@ bool IsForwardConvolutionCanonical(const HloInstruction& conv) { // dilation), returns kPad and/or kSlice instructions that explicitly apply the // padding; otherwise returns the original input operand. When there is both // positive padding (including dilation) and negative padding, we insert both -// kPad and kSlice. +// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved +// into a kPad or kSlice op. HloInstruction* MaybePaddedAndSlicedInput( - const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, + Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* input) { HloComputation* computation = input->parent(); - if (!window_util::HasSymmetricPadding(conv_window) || - window_util::HasBaseDilation(conv_window)) { + if (!window_util::HasSymmetricPadding(*conv_window) || + window_util::HasBaseDilation(*conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. // @@ -62,12 +63,21 @@ HloInstruction* MaybePaddedAndSlicedInput( MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { int64 dim = conv_dnums.input_spatial_dimensions(i); - padding_config.mutable_dimensions(dim)->set_edge_padding_low( - std::max(0LL, conv_window.dimensions(i).padding_low())); - padding_config.mutable_dimensions(dim)->set_edge_padding_high( - std::max(0LL, conv_window.dimensions(i).padding_high())); - padding_config.mutable_dimensions(dim)->set_interior_padding( - conv_window.dimensions(i).base_dilation() - 1); + if (conv_window->dimensions(i).padding_low() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_low( + conv_window->dimensions(i).padding_low()); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_high( + conv_window->dimensions(i).padding_high()); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } + if (conv_window->dimensions(i).base_dilation() != 1) { + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window->dimensions(i).base_dilation() - 1); + conv_window->mutable_dimensions(i)->set_base_dilation(1); + } } PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction( @@ -75,7 +85,7 @@ HloInstruction* MaybePaddedAndSlicedInput( input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } - if (window_util::HasNegativePadding(conv_window)) { + if (window_util::HasNegativePadding(*conv_window)) { // If the window has negative padding, insert a kSlice that explicitly // applies negative padding. // @@ -89,10 +99,14 @@ HloInstruction* MaybePaddedAndSlicedInput( int64 dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or // decrement the limit index by the amount of negative padding. - start_indices[dim] += - std::max(0LL, -conv_window.dimensions(i).padding_low()); - limit_indices[dim] -= - std::max(0LL, -conv_window.dimensions(i).padding_high()); + if (conv_window->dimensions(i).padding_low() < 0) { + start_indices[dim] += -conv_window->dimensions(i).padding_low(); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() < 0) { + limit_indices[dim] -= -conv_window->dimensions(i).padding_high(); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } } input = @@ -140,25 +154,22 @@ bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution( // Insert slices and/or pads between the convolution and its input and/or // kernel operand. + Window new_conv_window = conv->window(); HloInstruction* new_input = MaybePaddedAndSlicedInput( - conv->window(), conv->convolution_dimension_numbers(), + &new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(0)); HloInstruction* new_kernel = - MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(), + MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(1)); - // Remove the padding from convolution's window field. These paddings are - // made explicit with the inserted pads. - Window new_conv_window = conv->window(); + // Remove the window dilation from convolution's window field. These paddings + // are made explicit with the pads inserted by MaybePaddedKernel(). for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { WindowDimension* dim = new_conv_window.mutable_dimensions(i); // The size of the kernel may have changed so update the Window to match. dim->set_size(new_kernel->shape().dimensions( conv->convolution_dimension_numbers().kernel_spatial_dimensions(i))); - dim->set_padding_low(0); - dim->set_padding_high(0); - dim->set_base_dilation(1); dim->set_window_dilation(1); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 42fb38dffae31b0f4322216545027e067cab250d..33e41a2782b5932430eea621d3cea2c6634f292f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -268,5 +268,17 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { + return b->CreateAnd( + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)), + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index f373d4a8393a047aba599b0fae954e98a740161e..ebf4d926b7a280e10b09a2532caba7ad6ab3ceb2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -155,6 +155,10 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Emits code that determines whether the current thread is thread 0 within +// block 0 of the kernel. +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bbe1583c01167b3fbb50e066ad59a48e45f5e683..fb040aff30d48bf5817946ce53d37bc6685941e4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -547,7 +548,91 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // TODO(b/112040122): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } - return EmitReductionToVector(fusion); + VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); + std::vector> thunks; + absl::Span output_instructions = + root->opcode() == HloOpcode::kTuple + ? root->operands() + : absl::Span(&root, 1); + + // For multi-output fusion emit an initializer for each tuple element. + // Otherwise it's sufficient to just initialize the single output. + HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() == HloOpcode::kReduce) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion, output_instructions[i] == root + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + first_reduce = + first_reduce == nullptr ? output_instructions[i] : first_reduce; + } + } + CHECK(first_reduce != nullptr); + std::unique_ptr kernel_thunk = + BuildKernelThunk(fusion, /*implements_whole_instruction=*/false); + GpuElementalIrEmitter elemental_emitter( + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), + &elemental_emitter); + TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); + + // For multi-output fusion CHECK the constraints and feed all the + // reduces into a single loop code generator. Single-output reduce + // fusion is a special case of that. + InlinedVector input_gens; + InlinedVector init_value_gens; + std::vector> + extra_output_gens; + InlinedVector reducers; + InlinedVector reduce_output_shapes; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (root->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + if (inst->opcode() == HloOpcode::kReduce) { + CHECK(IsReductionToVector(*inst)) + << "Only reductions to vector are supported"; + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + CHECK(first_reduce->dimensions() == inst->dimensions()); + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + init_value_gens.push_back( + fused_emitter.GetGenerator(inst->operand(1))); + reducers.push_back(inst->to_apply()); + reduce_output_shapes.push_back(std::move(output_shape_index)); + } else { + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + TF_CHECK_OK(EmitReductionToVector( + kernel_thunk.get(), first_reduce, input_shape, input_gens, + init_value_gens, first_reduce->dimensions(), reducers, + reduce_output_shapes, extra_output_gens)); + thunks.push_back(std::move(kernel_thunk)); + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), fusion); + AddThunkToThunkSequence(std::move(sequential_thunk)); + return Status::OK(); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -617,12 +702,13 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* unnested_hlo, const IrArray::Index& index, + const HloInstruction* reduce, const IrArray::Index& index, absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { + const HloInstruction* output = reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = - GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) + GetIrArray(*output, *output, extra_output_gens[i].second) .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, @@ -632,13 +718,984 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( return Status::OK(); } +Status IrEmitterUnnested::EmitReductionToScalar( + KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens) { + // Number of elements processed by a single thread. + constexpr int64 kTileSize = 16; + int64 num_elems = ShapeUtil::ElementsIn(input_shape); + + // Round up the number of tiles to a multiple of the warp size. This is + // necessary for correctness. We launch one thread per tile, and if the + // number of threads isn't a multiple of the number of the warp size, our + // shuffles will read from inactive threads, producing undefined values. + int64 num_tiles = + RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); + + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {num_tiles}, {0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); + + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Check whether every thread will process a full tile's worth of elements + // without reading outside the bounds of the input. If this is true, we can + // skip some bounds checks in the final algorithm. + bool all_threads_in_bounds = num_tiles * kTileSize == num_elems; + + // __global__ void full_reduce_kernel() { + // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x; + // x = x_in_tiles * kTileSize; + // + // partial_result = init_value; + // if (all_threads_in_bounds || x + kTileSize <= num_elems) { + // for (i = 0; i < kTileSize; ++i) { + // partial_result = Reducer(partial_result, input[x + i]); + // } + // } else { + // for (i = 0; i < kTileSize; ++i) { + // if (x + i < num_elems) { + // partial_result = Reducer(partial_result, input[x + i]); + // } + // } + // } + // for (i = warpSize / 2; i > 0; i /= 2) { + // partial_result = Reducer(partial_result, + // __shfl_down(partial_result, i)); + // } + // if (lane_id == 0) { + // AtomicReducer(&output[y], partial_result); + // } + // } + // + // // Choose num_blocks and threads_per_block such that: + // // + // // num_blocks * threads_per_block = + // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), + // // + // // and threads_per_block is a multiple of warpSize. + // reduce_kernel // + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); + llvm::Type* element_ir_type = + llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + Store(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); + } + + llvm::Value* x_in_tiles = tile_index[0]; + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); + + // Emit an inner for-loop that reduces the elements in the tile. + auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + std::unique_ptr tile_element_loop = + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileSize), index_typed_constant(1), &b_); + + // Emit the body of the partial reduction loop. + llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), + &b_); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), + tile_element_loop->GetIndVarValue()); + // Unless we know the tile is entirely in bounds, we have to emit a + // x-in-bounds check before reading from the input. + if (!tile_in_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); + + // Emit code that reads the input element and accumulates it to + // the partial reduction result. + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); + } + + IrArray::Index input_index( + /*linear=*/x, input_shape, &b_); + llvm::Value* input_address = Alloca(element_ir_type); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + Store(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); + }; + + // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's + // immediately beyond the tile. + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileSize), + NSWMul(x_in_tiles, index_typed_constant(kTileSize))); + // The tile is entirely in bound if all_threads_in_bounds or + // x_end <= num_elems. + llvm::Value* tile_in_bounds = + Or(ICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); + llvm_ir::LlvmIfData if_tile_in_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); + + // After the if-then-else statement on tile_in_bounds, emit calls to + // shfl_down that accumulate the partial reduction results of all threads + // from the warp. + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? b_.getIntNTy(bit_width) + : element_ir_type; + for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; + shuffle_distance /= 2) { + llvm::Value* result_from_other_lane = + Alloca(element_ir_type, nullptr, "result_from_other_lane"); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } + } + + const HloInstruction* output = + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; + + // Emit an atomic operation that accumulates the partial reduction result of + // lane 0 (which holds the partially accumulated result for its warp) to the + // output element. + llvm::Value* lane_id = + URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); + + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + IrArray::Index( + /*linear=*/b_.getInt64(0), + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); + }; + + // Emit a parallel loop that iterates through all input tiles, one per thread. + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, + launch_dimensions, &b_) + .EmitLoop(IrName(reduce), index_ty); +} + +Status IrEmitterUnnested::EmitColumnReduction( + KernelThunk* kernel_thunk, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens) { + // Divide the input matrix into tiles of size KxL. For example, when the + // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like + // + // 0123 + // 0123 + // 4567 + // 4567 // Numbers indicate tile IDs. + // + // Each tile is first partially reduced to a scalar by a thread, and then the + // scalar is accumulated to the output vector using atomic operations. + // + // We choose 128 as the tile size based on empirical evidence. It's big enough + // to reduce the amount of atomic adds in the end, maximizing the memory + // bandwidth. A tile width of 2 allows for high memory bandwidth utilization + // on 16b input data. + constexpr int64 kTileHeight = 128; + constexpr int64 kTileWidth = 2; + + // If the height is not a multiple of kTileHeight, we pad the bottom of the + // input matrix. + const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); + // If width is not a multiple of kTileWidth the rightmost thread will process + // fewer input elements. + const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); + Shape tiled_input_shape = + ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), + {height_in_tiles, width_in_tiles}, {1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_ty = b_.getInt64Ty(); + + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; + // linear_index < height_in_tiles * width_in_tiles; + // linear_index += blockDim.x * gridDim.x) { + // y_in_tiles = linear_index / width_in_tiles; + // x_in_tiles = linear_index % width_in_tiles; + // + // partial_results[kTileWidth] = init_values; + // tile_in_y_bounds = height % kTileHeight == 0 || + // y_in_tiles * kTileHeight + kTileHeight <= height; + // tile_in_x_bounds = width % kTileWidth == 0 || + // x_in_tiles * kTileWidth + kTileWidth <= width; + // // The implementation handles y and x bound checks separately. + // if (tile_in_y_bounds && tile_in_x_bounds) { + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (x_offset : range(kTileWidth)) { + // x = x_in_tiles * kTileWidth + x_offset; + // partial_result = Reducer(partial_result[x_offset], input[y][x]); + // } + // } + // } else { + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (y_offset : range(kTileHeight)) { + // x = x_in_tiles * kTileWidth + x_offset; + // if (y < height && x < width) { + // partial_result = Reducer(partial_result, input[y][x]); + // } + // } + // } + // } + // for (x_offset : range(kTileWidth)) { + // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); + // } + // } + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); + // Emit the loop body that reduces one tile. + llvm::Type* element_ir_type = + llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* partial_reduction_result_address = + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + + llvm::Twine(i * kTileWidth + x_offset)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + Store(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); + } + } + + // Emit an inner for-loop that partially reduces the elements in the given + // tile. + llvm::Value* y_in_tiles = tile_index[0]; + llvm::Value* x_in_tiles = tile_index[1]; + + y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); + + auto emit_tile_element_loop = [=](bool tile_in_y_bounds, + bool tile_in_x_bounds) -> Status { + std::unique_ptr tile_element_loop = + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileHeight), index_typed_constant(1), &b_); + + // Emit the body of the partial reduction loop. + llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), + &b_); + llvm::Value* y = + NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), + tile_element_loop->GetIndVarValue()); + + // Unless we know that y is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_y_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); + + // Emit code that reads the input element and accumulates it to + // the partial reduction result. + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); + } + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + // Unless we know that x is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_x_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); + } + llvm::Value* input_address = Alloca(element_ir_type); + // {y,x} is an index to input_matrix_shape [height,width]. We need to + // convert that to an index to input_shape (the shape of the operand of + // "reduce"). This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_matrix_shape. + const Shape normalized_input_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + + const Shape input_matrix_shape = + ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), + {height, width}); + const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, + &b_); + const IrArray::Index input_index = + input_matrix_index + .SourceIndexOfReshape(input_matrix_shape, + normalized_input_shape, &b_) + .SourceIndexOfTranspose(normalized_input_shape, input_shape, + transpose_dimension_mapping, &b_); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + Store(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i * kTileWidth + x_offset], + input_address}, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens)); + } + } + return Status::OK(); + }; + + // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location + // that's immediately beyond the tile. + llvm::Value* y_end = + NSWAdd(index_typed_constant(kTileHeight), + NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location + // that's immediately beyond the tile. + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileWidth), + NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* tile_in_y_bounds = + Or(ICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); + llvm::Value* tile_in_x_bounds = + Or(ICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); + // The tile is in y bounds if "height" is a multiple of kTileHeight or + // y_end <= height. + llvm_ir::LlvmIfData if_tile_in_y_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); + // The tile is in x bounds if "width" is a multiple of kTileWidth or + // x_end <= width. + llvm_ir::LlvmIfData if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/false)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); + if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/false)); + + // After the nested if-then-else statement on tile_in_y_bounds and + // tile_in_x_bounds, emit atomic operations to accumulate the partial + // reduction result to the output element. + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); + const HloInstruction* output = + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; + for (int i = 0; i != num_reduces; ++i) { + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + } + } + return Status::OK(); + }; + + // Emit a parallel loop that iterate through all input tiles. + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, + launch_dimensions, &b_) + .EmitLoop(IrName(reduce), index_ty); +} + +static std::pair ComputeKernelMappingSchemeForReduction( + int64 depth, int64 width, int64 kWarpSize) { + constexpr int64 kTargetNumElementsPerThread = 64; + int64 x_tile_size = kTargetNumElementsPerThread; + int64 z_tile_size = 1; + + // Only tile along the x dimension with tile size kTargetNumElementsPerThread + // if doing so doesn't require a slow version of loop with bound check on each + // dimension. A more sophisticated heuristics is to enable tile along the + // x dimension with tile size kTargetNumElementsPerThread when either width is + // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big + // enough so that only a small fraction of the threads execute the slow + // version of loop with bound check. + if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { + x_tile_size = 8; + z_tile_size = 8; + while (depth % z_tile_size != 0) { + z_tile_size -= 1; + } + } + + return std::pair(x_tile_size, z_tile_size); +} + +Status IrEmitterUnnested::EmitRowReduction( + KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens) { + // A naive algorithm is: + // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. + // 2. Partially reduces each tile to a scalar using one thread. + // 3. Accumulates that scalar to the output vector using atomic operations. + // + // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; + // linear_index < depth * height * width_in_tiles; + // linear_index += blockDim.x * gridDim.x) { + // int x_in_tiles = linear_index % width_in_tiles; + // int y = linear_index / width_in_tiles % height; + // int z = linear_index / (height * width_in_tiles); + // float partial_result = 0; + // for (element_id_in_tile : range(x_tile_size)) { + // int x = x_in_tiles * x_tile_size + element_id_in_tile; + // if (x < width) + // partial_result = reducer(partial_result, input[z][y][x]); + // } + // AtomicReducer(&output[y], partial_result); + // } + // + // Four optimizations are performed. + // + // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 + // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead + // of making each tile consecutive, we let make tile 0 column + // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures + // that threads in a warp access consecutive memory in one iteration (i.e. + // coalesced). In the above example, the warp that contains thread 0-31 + // accesses column 0-31 in the first iteration, and 32-63 in the second + // iteration, and so on. + // + // 2. Partially accumulate partial reduced results computed by threads in the + // same warp using shfl_down. Using shfl_down is faster than directly using + // atomic operations because shfl_down transfers the data between threads + // using shared memory and threads in the same warp run in lock step (thus no + // extra synchronization needed). See + // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ + // for details. The downside is, to produce correct results when using + // shfl_down, we need to guarantee threads in the same warp work on input + // elements with the same y, so the number of tiles in each row must be a + // multiple of 32. + // + // 3. Specialize the case that the entire tile is in bounds. When that is + // true, we don't need to emit "if(x 0; shuffle_distance /= 2) + // partial_result = Reducer( + // partial_result, + // __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance)); + // if (lane_id == 0) + // AtomicReducer(&output[y], partial_result); + // } + // + + int64 x_tile_size; + int64 z_tile_size; + std::tie(x_tile_size, z_tile_size) = + ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize); + + // Round the width in tiles up to the nearest multiple of kWarpSize, so that + // the use of shfl_down is valid. + const int64 width_in_tiles = + RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); + + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + auto loop_body_emitter = [=](const IrArray::Index& tile_index) { + const int num_reduces = reducers.size(); + llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( + input_shape.element_type(), ir_emitter_context_->llvm_module()); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + Store(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); + } + + llvm::Value* z_tile = tile_index[0]; + llvm::Value* y = tile_index[1]; + llvm::Value* x_tile = tile_index[2]; + + x_tile = ZExtOrTrunc(x_tile, index_ty); + + llvm::Value* warp_id = + UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); + llvm::Value* lane_id = + URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); + + // The x-location of the last element in this z-x-tile. + // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); + llvm::Value* last_x = NSWAdd( + lane_id, + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(index_typed_constant(x_tile_size - 1), + NSWMul(warp_id, index_typed_constant(x_tile_size))))); + + KernelSupportLibrary ksl( + &b_, + /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, + /*prevent_vectorization=*/false); + + // Emit a for-loop that partially reduces the elements in the given + // z-x-tile. + auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, + int64 x_tile_loop_bound) -> Status { + auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { + llvm::Value* z = + NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); + TF_RETURN_IF_ERROR(ksl.For( + "x_tile", + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(x_tile_loop_bound), + /*step=*/1, [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); + llvm::Value* x = NSWAdd( + lane_id, + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(x_indvar, + NSWMul(warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); + + // Unless we know the x-tile is entirely in bounds, we have to + // emit a x-in-bounds check before reading from the input. + if (!x_tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = + llvm_ir::EmitIfThenElse( + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); + // Points b_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &b_); + } + + // Emit code that reads the input element and accumulates it + // to the partial reduction result. + llvm::Value* input_address = Alloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape + // [depth,height,width]. We need to convert that to an index + // to input_shape (the shape of the operand of "reduce"). + // This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. + const Shape normalized_input_shape = ShapeUtil:: + MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = + LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {depth, height, width}); + const IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &b_); + const IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, &b_) + .SourceIndexOfTranspose( + normalized_input_shape, input_shape, + transpose_dimension_mapping, &b_); + + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + Store(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); + } + })); + return Status::OK(); + }; + + return ksl.For("z_tile", + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(z_tile_size), + /*step=*/1, emit_z_tile_element_loop); + }; + + llvm::Value* tile_in_bounds = + Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ICmpULT(last_x, index_typed_constant(width))); + + TF_RETURN_IF_ERROR( + ksl.If(tile_in_bounds, + /*true_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, + x_tile_size); + }, + /*false_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop( + /*x_tile_in_bounds=*/false, + CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); + })); + + // After accumulating the elements of the z_x_tile, emit calls to + // shfl_down that accumulate the partial reduction results of all + // threads in a warp. + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? b_.getIntNTy(bit_width) + : element_ir_type; + for (int shuffle_distance = 16; shuffle_distance >= 1; + shuffle_distance /= 2) { + llvm::Value* result_from_other_lane = + Alloca(element_ir_type, nullptr, "result_from_other_lane"); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } + } + + const HloInstruction* output = + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; + + // Emit an atomic operation that accumulates the partial reduction result of + // lane 0 (which holds the partially accumulated result for its warp) to the + // output element. + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + IrArray::Index(y, + ShapeUtil::GetSubshape( + output->shape(), reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); + // We don't need to emit atomic operations if there is only one tile of + // results. 'depth' is the z dimension, 'width' is the x dimension. + if (z_tile_size >= depth && x_tile_size >= width) { + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {output_address, partial_reduction_result_addresses[i]}, + output_address)); + } else { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); + } + } + return Status::OK(); + }; + + // Emit a parallel loop that iterates through every input tiles. + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, + launch_dimensions, &b_) + .EmitLoop(IrName(reduce), index_ty); +} + +// Figures out whether `reduce` is a row or column reduction, and which +// dimensions to reduce, and calls either `EmitRowReduction` or +// `EmitColumnReduction` as appropriate. +// Prerequisite: all the dimensions to keep are contiguous in the input layout +// and, if `reduce` is fused, the fused subgraph is pure +// elementwise. +Status IrEmitterUnnested::EmitReductionToVector( + KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens) { + // This emission requires "reduce" to have an input layout. It is either set + // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for + // a fused kReduce). + CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << reduce->ToString(); + + // Specialize multi-dimensional-array-to-vector reduction. + std::vector input_dims_to_keep; + for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); + ++input_dim) { + if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(), + input_dim) == dimensions_to_reduce.end()) { + input_dims_to_keep.push_back(input_dim); + } + } + + // Sort the dimensions to keep from minor to major, to facilitate checking + // whether another dimension is major or minor of them. + std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), + [&input_shape](int64 dim_a, int64 dim_b) { + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), + dim_a) < + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), + dim_b); + }); + // Now, if output rank is at least 1, `input_dims_to_keep.front()` is + // minormost and `input_dims_to_keep.back()` is majormost. + + // If the dimensions to keep are minormost, emit a column reduction. As all + // the dimensions to keep are contiguous, by prerequisite of + // `EmitReductionToVector`, we only need to check whether the minormost + // dimension of the input is to keep. + if (ShapeUtil::IsEffectiveScalar(reduce->shape())) { + return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens, + init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); + } else if (input_dims_to_keep.front() == + LayoutUtil::Minor(input_shape.layout(), 0)) { + // Column reduction. Treat the result of "input" as a matrix whose width + // is the most minor dimension and height the product of other dimensions, + // and treat "reduce" as a column reduction of the input matrix. + const int64 width = ShapeUtil::ElementsIn(reduce->shape()); + // "width" can be zero, so don't do + // height = ShapeUtil::ElementsIn(input_shape) / width; + int64 height = 1; + for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); + ++input_dim) { + if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(), + input_dim)) { + height *= input_shape.dimensions(input_dim); + } + } + return EmitColumnReduction(kernel_thunk, height, width, reduce, input_shape, + input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); + } else { + // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a + // 3D tensor. The size of dimension 1 (the height) is the size of the + // dimension to keep, the size of dimension 0 (the depth) is the product + // of dimensions that are more major than the dimension to keep, and the + // size of dimension 2 (the width) is the product of more minor + // dimensions. + int64 depth = 1; + int64 width = 1; + for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); + ++input_dim) { + if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), + input_dim) > + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), + input_dims_to_keep.back())) { + depth *= input_shape.dimensions(input_dim); + } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), + input_dim) < + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), + input_dims_to_keep.front())) { + width *= input_shape.dimensions(input_dim); + } + } + const int64 height = ShapeUtil::ElementsIn(reduce->shape()); + return EmitRowReduction(kernel_thunk, depth, height, width, reduce, + input_shape, input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); + } +} + Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support multi-output reduce. if (!ShapeUtil::IsArray(reduce->shape())) { return Unimplemented("Multi-output reduce is not supported on GPU"); } + auto input = reduce->operand(0); + auto init_value = reduce->operand(1); + absl::Span dimensions_to_reduce(reduce->dimensions()); + HloComputation* reducer = reduce->to_apply(); + // HandleReduce specializes reduction from a multi-dimensional array to a 1D + // array. The specialized version requires an initializer thunk that + // initializes the output array to the initial value of the reduce. if (IsReductionToVector(*reduce)) { - return EmitReductionToVector(reduce); + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(reduce)); + std::vector> thunks; + thunks.push_back(std::move(initializer_thunk)); + std::unique_ptr kernel_thunk = + BuildKernelThunk(reduce, /*implements_whole_instruction=*/false); + + TF_CHECK_OK(EmitReductionToVector( + kernel_thunk.get(), reduce, input->shape(), + {[&](const IrArray::Index& index) { + return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); + }}, + {[&](const IrArray::Index& index) { + return GetIrArray(*init_value, *reduce) + .EmitReadArrayElement(index, &b_); + }}, + dimensions_to_reduce, {reducer}, {{}}, {})); + + thunks.push_back(std::move(kernel_thunk)); + + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), reduce); + AddThunkToThunkSequence(std::move(sequential_thunk)); + return Status::OK(); } return IrEmitter::HandleReduce(reduce); @@ -763,7 +1820,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, index_type); - DimensionVector window_size; + std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); @@ -2059,8 +3116,18 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multioutput fusion, we need to emit each operand and the root. + // Emit the tuple pointers in one thread. We could do this at any point in + // the kernel, but we do it at the beginning in the hopes of reducing register + // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the + // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); + TF_RETURN_IF_ERROR( + KernelSupportLibrary(&b_).If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); + })); + + // For multioutput fusion, we need to emit each operand and the root. TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2069,8 +3136,6 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( &hlo, launch_dimensions.launch_bound(), &b_))); b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); } @@ -2130,37 +3195,6 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the -// thread lives within a square tile of size tile_size (so thread blocks are of -// size tile_size * tile_size). -std::tuple CalculateYXCoordinateWithinTile( - llvm::IRBuilder<>* builder, llvm::Value* tile_size, - int64 threads_per_tile) { - // Calculate the starting element coordinate within a tile for the current - // thread, (y, x) from thread_id. - llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_tile, - llvm::cast(thread_id)); - thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), - /*isSigned=*/true, "thread.id.x"); - auto x = builder->CreateURem(thread_id, tile_size); - auto y = builder->CreateUDiv(thread_id, tile_size); - return std::make_tuple(y, x); -} - -// Reads block_idx.x, casts it to type index_ty, and adds the assumption that -// it's in the range [0, num_blocks]. -llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, - int64 num_blocks) { - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id)); - return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, - "block.id.x"); -} - void EmitFullTile(const KernelMappingScheme* mapping_scheme, const IrArray::Index& tile_origin_index, llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, @@ -2208,8 +3242,7 @@ void EmitPartialTile( builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); ksl->IfReturnVoid( - loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), - [&] { + "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { // tile_height_bound = // ceil(tile_height / num_threads_y) * num_threads_y llvm::Value* ceiling_of_ratio = builder->CreateUDiv( @@ -2226,8 +3259,8 @@ void EmitPartialTile( [&](llvm::Value* y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); ksl->IfReturnVoid( - loop_name + "_y_in_tile", - builder->CreateICmpULT(y_loc, tile_height), [&] { + "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), + [&] { emit_elem_function( source_idx.AddOffsetToDim( y_indvar, KernelMappingScheme::DimY, builder), @@ -2258,7 +3291,7 @@ void EmitTiledElementalCodeWithBoundsCheck( llvm::Type* index_ty = tile_width->getType(); ksl->IfReturnVoid( - loop_name + "_full_tile", + "full_tile", builder->CreateAnd( builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), tile_width), @@ -2349,395 +3382,7 @@ void IrEmitterUnnested::EmitTileElementForFusion( } } -// Information to support the code generation for a tiled reduction kernel. -using AddressVector = InlinedVector; -class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { - public: - explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme, - bool is_row_reduction) - : KernelCodegenInfo(mapping_scheme), - current_output_linear_index_address_(nullptr), - current_output_inbound_address_(nullptr), - is_row_reduction_(is_row_reduction) {} - - void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) { - current_output_linear_index_address_ = a; - } - // Returns the address of the memory that stores the linear index of the - // current output. Since we are processing reduction to contiguous physical - // dimensions, this linear index is the linear index of the 1D output array. - llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const { - return current_output_linear_index_address_; - } - - void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) { - current_output_inbound_address_ = a; - } - - llvm::AllocaInst* GetCurrentOutputInboundAddress() const { - return current_output_inbound_address_; - } - - AddressVector* GetMutablePartialResultAddresses() { - return &partial_result_addresses_; - } - const AddressVector& GetPartialResultAddresses() const { - return partial_result_addresses_; - } - - AddressVector* GetMutableReductionInputAddresses() { - return &reduction_input_addresses_; - } - const AddressVector& GetReductionInputAddresses() const { - return reduction_input_addresses_; - } - - InlinedVector* GetMutableReducers() { return &reducers_; } - const InlinedVector& GetReducers() const { - return reducers_; - } - int GetNumberOfReduces() const { return reducers_.size(); } - - InlinedVector* GetMutableReductionOutputShapeIndices() { - return &reduction_output_shape_indices_; - } - const InlinedVector& GetReductionOutputShapeIndices() const { - return reduction_output_shape_indices_; - } - - bool IsRowReduction() const { return is_row_reduction_; } - - // Return the dimension that is being reduced between DimX and DimY. - int GetReducedDimensionEnum() const { - return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX - : llvm_ir::KernelMappingScheme::DimY; - } - - // Return the dimension that is being ketp between DimX and DimY. - int GetKeptDimensionEnum() const { - return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimY - : llvm_ir::KernelMappingScheme::DimX; - } - - private: - AddressVector partial_result_addresses_; - AddressVector reduction_input_addresses_; - InlinedVector reducers_; - InlinedVector reduction_output_shape_indices_; - llvm::AllocaInst* current_output_linear_index_address_; - llvm::AllocaInst* current_output_inbound_address_; - bool is_row_reduction_; -}; - -namespace { -// Returns a group of instructions that generate the output for the kernel -// containing the given HLO instruction. The result may be an unnested kReduce -// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple -// for a multiple output fusion. -absl::Span GetOutputInstructions( - HloInstruction* const* reduce_or_tuple_pointer) { - HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode(); - CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple); - return opcode == HloOpcode::kTuple - ? (*reduce_or_tuple_pointer)->operands() - : absl::Span(reduce_or_tuple_pointer, 1); -} - -const HloInstruction* GetFirstReduceInstruction( - absl::Span instructions) { - auto first_reduce_iter = - absl::c_find_if(instructions, [](const HloInstruction* inst) { - return inst->opcode() == HloOpcode::kReduce; - }); - CHECK_NE(first_reduce_iter, instructions.end()); - return *first_reduce_iter; -} - -}; // namespace - -void IrEmitterUnnested::EmitPrologueForOneReduction( - HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx, - KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter, - ShapeIndex output_shape_index) { - ReductionCodegenInfo* reduction_info = - static_cast(kernel_info); - - InlinedVector* reducers = - reduction_info->GetMutableReducers(); - CHECK(IsReductionToVector(*reduce_inst)); - reducers->push_back(reduce_inst->to_apply()); - - InlinedVector* reduction_output_shape_indices = - reduction_info->GetMutableReductionOutputShapeIndices(); - reduction_output_shape_indices->push_back(std::move(output_shape_index)); - - AddressVector* reduction_input_addresses = - reduction_info->GetMutableReductionInputAddresses(); - llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( - reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module()); - llvm::AllocaInst* reduction_input_address = Alloca(element_type); - reduction_input_addresses->push_back(reduction_input_address); - - AddressVector* partial_result_addresses = - reduction_info->GetMutablePartialResultAddresses(); - llvm::AllocaInst* partial_result_address = - Alloca(element_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(reduce_idx)); - partial_result_addresses->push_back(partial_result_address); - - // Initialize the partial result with the initial value of the reduction. - llvm::Value* init_ir_value; - if (unnested_hlo->opcode() == HloOpcode::kFusion) { - HloInstruction* init_value_operand = reduce_inst->mutable_operand(1); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), - elemental_emitter); - - TF_CHECK_OK(init_value_operand->Accept(&fused_emitter)); - init_ir_value = - fused_emitter - .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty())) - .ValueOrDie(); - } else { - const HloInstruction* init_value = unnested_hlo->operand(1); - init_ir_value = - GetIrArray(*init_value, *unnested_hlo) - .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); - } - - Store(init_ir_value, partial_result_address); -} - -void IrEmitterUnnested::EmitPrologueForReduction( - HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { - VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString(); - // Find the unnested kReduce or the tuple that contains a list of kReduce. - HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion - ? unnested_hlo->fused_expression_root() - : unnested_hlo; - absl::Span output_instructions = - GetOutputInstructions(&reduce_or_tuple); - ReductionCodegenInfo* reduction_info = - static_cast(kernel_info); - GpuElementalIrEmitter elemental_emitter(hlo_module_config_, - ir_emitter_context_->llvm_module(), - &b_, GetNestedComputer()); - const HloInstruction* first_reduce = nullptr; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() != HloOpcode::kReduce) { - continue; - } - HloInstruction* reduce_inst = output_instructions[i]; - if (first_reduce == nullptr) { - first_reduce = reduce_inst; - } else { - CHECK(first_reduce->dimensions() == reduce_inst->dimensions()); - } - ShapeIndex output_shape_index; - if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { - output_shape_index = {i}; - } - - EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info, - &elemental_emitter, - std::move(output_shape_index)); - } - - // Allocate stack storage to store the current output linear index and record - // the address of the storage. - reduction_info->SetCurrentOutputLinearIndexAddress( - Alloca(reduction_info->GetIndexType())); - - if (!reduction_info->IsRowReduction()) { - llvm::Type* bool_ty = b_.getInt1Ty(); - llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty); - Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr); - reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr); - } -} - -void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( - const InlinedVector& reducers, - const AddressVector& partial_result_addresses) { - for (int distance = 16; distance >= 1; distance /= 2) { - for (int i = 0; i != reducers.size(); ++i) { - llvm::Type* element_type = - partial_result_addresses[i]->getType()->getElementType(); - int bit_width = llvm_ir::GetSizeInBits(element_type); - llvm::Value* result_from_other_lane = Alloca( - element_type, nullptr, "result_from_other_lane" + llvm::Twine(i)); - // Bitcast cannot be applied to aggregate types (even packed ones), so - // we bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffled_value_type = - element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type; - auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { - return BitCast(ptr, shuffled_value_type->getPointerTo()); - }; - llvm::Value* partial_result = - Load(convert_pointer_for_shuffle(partial_result_addresses[i]), - "partial_reduction_result"); - Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), - convert_pointer_for_shuffle(result_from_other_lane)); - TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], {partial_result_addresses[i], result_from_other_lane}, - partial_result_addresses[i])); - } - } -} - -void IrEmitterUnnested::EmitEpilogueForReduction( - HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { - ReductionCodegenInfo* reduction_info = - static_cast(kernel_info); - int num_reduces = reduction_info->GetNumberOfReduces(); - const AddressVector& partial_result_addresses = - reduction_info->GetPartialResultAddresses(); - const InlinedVector& reducers = - reduction_info->GetReducers(); - const InlinedVector& reduction_output_shape_indices = - reduction_info->GetReductionOutputShapeIndices(); - - if (reduction_info->IsRowReduction()) { - EmitFullWarpShuffleDownLoopForAllReduces(reducers, - partial_result_addresses); - llvm::Value* lane_id = reduction_info->GetLaneId(); - llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)), - "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - } else { - llvm::Value* output_inbound_addr = - reduction_info->GetCurrentOutputInboundAddress(); - llvm::Value* output_inbound = Load(output_inbound_addr); - llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse( - ICmpEQ(output_inbound, - llvm::ConstantInt::get(output_inbound->getType(), 1)), - "output_inbound", &b_); - llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); - } - - // Emit an atomic operation that accumulates the partial reduction to the - // output element. For row reduction, this is only for lane 0 due to the - // if-statement emitted above. - for (int i = 0; i != num_reduces; ++i) { - IrArray::Index element_index( - /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(), - "output_linear_addr"), - ShapeUtil::GetSubshape(unnested_hlo->shape(), - reduction_output_shape_indices[i]), - &b_); - llvm::Value* output_address = - GetIrArray(*unnested_hlo, *unnested_hlo, - reduction_output_shape_indices[i]) - .EmitArrayElementAddress(element_index, &b_, - "output_element_address"); - // Do not emit atomic operations if each element in the reduction result is - // computed by one block, that is the dimension being reduced has only one - // block. - const llvm_ir::KernelMappingScheme* mapping_scheme = - reduction_info->GetKernelMappingScheme(); - if (mapping_scheme->GetTileBlockSizeForDimension( - llvm_ir::KernelMappingScheme::DimZ) == 1 && - mapping_scheme->GetTileBlockSizeForDimension( - reduction_info->GetReducedDimensionEnum()) == 1) { - TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], {output_address, partial_result_addresses[i]}, - output_address)); - } else { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_result_addresses[i])); - } - } -} - -void IrEmitterUnnested::EmitTileElementForReduction( - HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { - VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); - HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion - ? unnested_hlo->fused_expression_root() - : unnested_hlo; - llvm_ir::TiledParameterInfo* tiled_param_info = - kernel_info->GetTiledParameterInfo(); - tiled_param_info->set_y(y_loc); - tiled_param_info->set_x(x_loc); - - // Record the linear address for the current reduction. - const ReductionCodegenInfo* reduction_info = - dynamic_cast(kernel_info); - Store(index[reduction_info->GetKeptDimensionEnum()], - reduction_info->GetCurrentOutputLinearIndexAddress()); - if (!reduction_info->IsRowReduction()) { - llvm::Type* bool_ty = b_.getInt1Ty(); - llvm::AllocaInst* output_inbound_addr = - reduction_info->GetCurrentOutputInboundAddress(); - Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr); - } - - InlinedVector input_gens; - std::vector> - extra_output_gens; - GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), - &elem_emitter); - absl::Span output_instructions = - GetOutputInstructions(&reduce_or_tuple); - // Construct the ElementGenerator for each reduction and extra output in the - // the group of output instructions. - if (unnested_hlo->opcode() == HloOpcode::kFusion) { - fused_emitter.SetTiledParameterInfo(tiled_param_info); - TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); - - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - const HloInstruction* inst = output_instructions[i]; - ShapeIndex output_shape_index; - if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { - output_shape_index = {i}; - } - if (inst->opcode() == HloOpcode::kReduce) { - input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); - } else { - extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), - std::move(output_shape_index)); - } - } - } else { - input_gens.push_back([&](const IrArray::Index& index) { - return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo) - .EmitReadArrayElement(index, &b_); - }); - } - - IrArray::Index input_index = - reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( - index, - GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); - const AddressVector& partial_reduction_result_addresses = - reduction_info->GetPartialResultAddresses(); - const AddressVector& reduction_input_addresses = - reduction_info->GetReductionInputAddresses(); - const InlinedVector& reducers = - reduction_info->GetReducers(); - - // Emit code to generate the input and perform the reduction computation for - // each reduction instruction. - for (int i = 0; i != reducers.size(); ++i) { - llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); - Store(input_ir_value, reduction_input_addresses[i]); - TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], reduction_input_addresses[i]}, - partial_reduction_result_addresses[i])); - } - - // Emit code to generate the output for the non-reduction instructions in the - // fusion, if any. - TF_CHECK_OK( - EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens)); -} - -// Emits a kernel for the hlo instruction using the given tiling scheme. +// Emits a block of tiles, given a function object to emit one tile. void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, const KernelCodegenInfo* kernel_info, KernelSupportLibrary& ksl, @@ -2864,6 +3509,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( << llvm_ir::DumpToString(*param_shmem_buffers[id]); } + CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0); LaunchDimensions launch_dimensions = LaunchDimensions( mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); llvm::Type* index_ty = GetIndexTypeForKernel( @@ -2872,6 +3518,21 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( return llvm::ConstantInt::get(index_ty, c); }; + // For multioutput fusion, one thread needs to output a tuple with pointers to + // all the individual outputs. We could do this at any point in the kernel, + // but we do it at the beginning in the hopes of reducing register pressure, + // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel + // *anyway*. + if (unnested_hlo->IsMultiOutputFusion()) { + TF_CHECK_OK(KernelSupportLibrary(&b_).If( + "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_, + module_); + return Status::OK(); + })); + } + // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. std::vector param_in_reduced_shape_arrays; @@ -2892,7 +3553,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_info->SetLaneId( mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x : nullptr); - kernel_info->SetIndexType(index_ty); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. @@ -2917,31 +3577,29 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); + // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread // copies it from input to tile. This is `__syncthreads` in CUDA. if (!tiled_param_ids.empty()) { - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - // Note that tile_width and tile_height are flipped here because we are - // reading a transposed tile. - emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = - param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement( - index, &b_, "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); - } - }); - - // Wait for all threads to reach this point using `__syncthreads` in CUDA. llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); } @@ -2961,7 +3619,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_generator.GetTileElementGenerator()(unnested_hlo, index, kernel_info, y_loc, x_loc); }); - // If a tile block contains multiple tiles and shared memory buffers are // used, we need to wait for all threads to finish using the shared memory // buffer for the current tile before we move on to process the next tile @@ -2985,15 +3642,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( block_epilogue_generator(unnested_hlo, kernel_info); } - // For multioutput fusion, emit a tuple with pointers to all the individual - // outputs. - if (unnested_hlo->IsMultiOutputFusion()) { - std::vector output_arrays = - ConstructIrArrayForOutputs(*unnested_hlo); - llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), output_arrays, - &b_, module_); - } - return launch_dimensions; } @@ -3166,246 +3814,6 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } -namespace { -// Checks that the outputs of a fusion with reduction are consistent. -Status AreFusedReductionOutputsConsistent( - absl::Span output_instructions, - const HloInstruction* first_reduce) { - for (const HloInstruction* inst : output_instructions) { - if (inst->opcode() == HloOpcode::kReduce) { - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. - TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); - TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); - TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions()); - } else { - // For extra outputs we can relax shape equality to allow different - // types (with the same number of elements). Layouts still have to - // match. - TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape())); - TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout())); - } - } - return Status::OK(); -} - -// Finds the dimensions to keep for the reduction, sorts and returns the -// dimensions from minor to major. -DimensionVector GetDimensionsToKeepMinorToMajor( - const Shape& input_shape, absl::Span dims_to_reduce) { - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); - absl::c_iota(input_dims, 0); - DimensionVector input_dims_to_keep; - for (int input_dim : input_dims) { - auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) { - return dim_to_reduce == input_dim; - }); - if (it == dims_to_reduce.end()) { - input_dims_to_keep.push_back(input_dim); - } - } - - // Sort the dimensions to keep from minor to major. - absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); - }); - - VLOG(10) << "dims to keep minor to major" - << absl::StrJoin(input_dims_to_keep, ","); - return input_dims_to_keep; -} - -// Given the input shape and dimensions to reduce for the reduction to vector, -// returns : -// num_kept: the number of elements in the contiguous dimensions to keep. -// num_reduced_major: the number of elements in the dimensions to reduce that -// are more major than the dimensions to keep. -// num_reduced_minor: the number of elements in the dimensions to reduce that -// are more minor than the dimensions to kept. -std::tuple GetReductionToVectorDimensions( - const Shape& input_shape, absl::Span dims_to_reduce) { - DimensionVector input_dims_to_keep_minor_to_major = - GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce); - CHECK(LayoutUtil::AreDimensionsConsecutive( - input_shape.layout(), input_dims_to_keep_minor_to_major)); - int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1; - if (input_dims_to_keep_minor_to_major.empty()) { - return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); - } - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); - absl::c_iota(input_dims, 0); - absl::Span minor_to_major = - LayoutUtil::MinorToMajor(input_shape); - for (int input_dim : input_dims) { - int64 curr_dim_size = input_shape.dimensions(input_dim); - if (PositionInContainer(minor_to_major, input_dim) > - PositionInContainer(minor_to_major, - input_dims_to_keep_minor_to_major.back())) { - num_reduced_major *= curr_dim_size; - } else if (PositionInContainer(minor_to_major, input_dim) < - PositionInContainer(minor_to_major, - input_dims_to_keep_minor_to_major.front())) { - num_reduced_minor *= curr_dim_size; - } else { - num_kept *= curr_dim_size; - } - } - - return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); -} - -std::tuple ComputeMappingSchemeAndReductionKind( - const HloInstruction* first_reduce, llvm::IRBuilder<>* b) { - int64 depth = 1; - int64 height = 1; - int64 width = 1; - bool is_row_reduction = true; - int64 tile_size_x = 1; - int64 tile_size_y = 1; - int64 block_size_y = 1; - int64 block_size_z = 1; - int64 num_threads_x = 1; - int64 num_threads_y = 1; - const Shape& input_shape = first_reduce->operand(0)->shape(); - int64 num_input_elems = ShapeUtil::ElementsIn(input_shape); - int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape()); - int64 num_reduced_major, num_kept, num_reduced_minor; - std::tie(num_reduced_major, num_kept, num_reduced_minor) = - GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); - CHECK_EQ(num_output_elems, num_kept); - - if (num_kept == 1) { - // Scalar reduction is a special row reduction with depth = height = 1. - width = num_input_elems; - tile_size_x = kWarpSize * 16; - num_threads_x = kWarpSize; - } else if (num_reduced_minor == 1) { - // Column reduction reduces inputs with dimension [height, width], where - // width is the minor dimension, to dimension [width]. - height = num_reduced_major; - width = num_kept; - is_row_reduction = false; - tile_size_x = std::min(kWarpSize, num_kept); - // The old Column reduction algorithm uses kTileHeight = 128. We choose - // tile_size_y * block_size_y = 128 to match the value of kTileHeight. Using - // a non-trivial block_size_y here is a way to avoid unrolling all the 128 - // iterations. - tile_size_y = 32; - block_size_y = 4; - num_threads_x = tile_size_x; - } else { - // Row reduction reduces inputs with dimension [depth, height, width], - // where width is the most minor dimension, to dimension [height] . - depth = num_reduced_major; - height = num_kept; - width = num_reduced_minor; - num_threads_x = kWarpSize; - if (width % (kWarpSize * 64) == 0) { - tile_size_x = kWarpSize * 64; - } else { - tile_size_x = kWarpSize * 8; - block_size_z = 8; - while (depth % block_size_z != 0) { - block_size_z -= 1; - } - } - } - DCHECK_EQ(depth * height * width, num_input_elems); - VLOG(10) << "is_row_reduction " << is_row_reduction << depth << " " << height - << " " << width; - - DimensionVector dims_in_elem{depth, height, width}; - DimensionVector req_block_sizes{block_size_z, block_size_y, 1}; - llvm_ir::KernelMappingScheme mapping_scheme(dims_in_elem, tile_size_y, - tile_size_x, req_block_sizes, - num_threads_y, num_threads_x, b); - return std::make_tuple(mapping_scheme, is_row_reduction); -} - -} // namespace - -Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { - VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); - - HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion - ? unnested_hlo->fused_expression_root() - : unnested_hlo; - absl::Span output_instructions = - GetOutputInstructions(&reduce_or_tuple); - const HloInstruction* first_reduce = - GetFirstReduceInstruction(output_instructions); - - if (output_instructions.size() > 1) { - TF_RETURN_IF_ERROR( - AreFusedReductionOutputsConsistent(output_instructions, first_reduce)); - } - - // Build an initializer thunk to initialize each reduction output. - std::vector> thunks; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() != HloOpcode::kReduce) { - continue; - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr initializer_thunk, - BuildInitializerThunk(unnested_hlo, - (output_instructions[i] == reduce_or_tuple) - ? ShapeIndex() - : ShapeIndex({i}))); - thunks.push_back(std::move(initializer_thunk)); - } - - // Build a kernel thunk to compute all the outputs. - std::unique_ptr kernel_thunk = - BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); - - const Shape& input_shape = first_reduce->operand(0)->shape(); - // The layout of a reduction input is either set by LayoutAssignment for - // unnested kReduce or by InstructionFusion for fused kReduce. - CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << first_reduce->ToString(); - - bool is_row_reduction; - llvm_ir::KernelMappingScheme mapping_scheme; - std::tie(mapping_scheme, is_row_reduction) = - ComputeMappingSchemeAndReductionKind(first_reduce, &b_); - ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); - KernelCodeGenerator kernel_generator( - /*tile_element_generator=*/ - [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { - EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc); - }, - /*block_prologue_generator=*/ - [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { - EmitPrologueForReduction(hlo, kernel_info); - }, - /*block_epilogue_generator*/ - [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { - EmitEpilogueForReduction(hlo, kernel_info); - }); - - LaunchDimensions launch_dimensions = - EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info); - UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), - ir_emitter_context_->llvm_module()); - - thunks.push_back(std::move(kernel_thunk)); - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), unnested_hlo); - AddThunkToThunkSequence(std::move(sequential_thunk)); - - return Status::OK(); -} - Status IrEmitterUnnested::EmitConstantGlobals() { for (const BufferAllocation& allocation : ir_emitter_context_->buffer_assignment().Allocations()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 85a0e5328c4e436d4522593b38421efc87c42d32..e09ed657a812be6ab4859a0e365a51c45a37bfed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ -#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" @@ -69,12 +68,9 @@ class IrEmitterUnnested : public IrEmitter { explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) : mapping_scheme_(mapping_scheme), tiled_param_info_(nullptr), - lane_id_(nullptr), - index_ty_(nullptr) {} - virtual ~KernelCodegenInfo() {} + lane_id_(nullptr) {} void SetLaneId(llvm::Value* v) { lane_id_ = v; } - void SetIndexType(llvm::Type* t) { index_ty_ = t; } void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { CHECK_EQ(tiled_param_info_, nullptr); tiled_param_info_ = tiled_param_info; @@ -87,13 +83,11 @@ class IrEmitterUnnested : public IrEmitter { llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { return tiled_param_info_; } - llvm::Type* GetIndexType() const { return index_ty_; } private: llvm_ir::KernelMappingScheme* mapping_scheme_; llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; - llvm::Type* index_ty_; }; // A function object to prepare for the code generation for a tile block. @@ -206,14 +200,82 @@ class IrEmitterUnnested : public IrEmitter { // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( - const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, absl::Span> extra_output_gens); - // Generates code for reduction to contiguous dimensions. + // EmitColumnReduction and EmitRowReduction emit code for column and row + // reduction of a matrix and/or 3D tensor. Row and column reduction have + // different memory access pattern, so for performance their implementations + // are significantly different. // - // Prerequisite: `IsReductionToVector(*unnested_hlo)` - Status EmitReductionToVector(HloInstruction* unnested_hlo); + // Emits code that reduces a matrix of shape [height x width] to a vector of + // [width]. Other parameters have the same meaning as those of + // `EmitReductionToVector`. Note that input shape might not be + // [height x width], but can be bitcast to [height x width] with "height" + // being the major dimension. + Status EmitColumnReduction( + KernelThunk* kernel_thunk, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens); + + // Emits code that reduces a 3D tensor of shape [depth x height x width] to a + // vector of shape [height]. Other parameters have the same meaning as those + // of `EmitReductionToVector`. Note that input shape might not be + // [depth x height x width], but can be bitcast to [depth x height x width] + // with "depth" being the most major dimension. + Status EmitRowReduction( + KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens); + + // Emits code that reduces a tensor of arbitrary rank to a scalar. + Status EmitReductionToScalar( + KernelThunk* kernel_thunk, HloInstruction* reduce, + const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens); + + // Figures out whether `reduce` is a row or column reduction, and which + // dimensions to reduce, and calls either `EmitRowReduction` or + // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the + // input array, which is the operand of the Reduce instruction if unfused or + // of the Fusion instruction if fused. `input_gen` and `init_value_gen` + // generate elements of the input and the initial value. Other parameters mean + // the same as for `HandleReduce`. + // + // Multiple reduces can be emitted in the same loop, assuming they have the + // same input and output shapes, and the same reduce dimensions. + // + // extra_output_gens can contain extra generators for intermediate outputs. + // These must have the same shape as the reduce input as they are computed + // when the reduce inputs are being read. + // + // Prerequisite: `IsReductionToVector(*reduce)` + Status EmitReductionToVector( + KernelThunk* kernel_thunk, HloInstruction* reduce, + const Shape& input_shape, + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> + extra_output_gens); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from @@ -252,29 +314,6 @@ class IrEmitterUnnested : public IrEmitter { const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, llvm::Value* x_loc); - // Emits code to process a tensor element in a tile for the given input hlo - // that is either a unnested kReduce or a kInput fusion. - void EmitTileElementForReduction(HloInstruction* unnested_hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); - // Prepares for the code generation for a tile block of a reduction kernel. - void EmitPrologueForReduction(HloInstruction* unnested_hlo, - KernelCodegenInfo* kernel_info); - void EmitPrologueForOneReduction(HloInstruction* unnested_hlo, - HloInstruction* reduce_inst, int reduce_idx, - KernelCodegenInfo* kernel_info, - GpuElementalIrEmitter* elemental_emitter, - ShapeIndex output_shape_index); - // Wraps up the code generation for a tile block of a reduction kernel. - void EmitEpilogueForReduction(HloInstruction* unnested_hlo, - KernelCodegenInfo* kernel_info); - // For each reducer, emits the shuffle-down loop to accumulate the partial - // result to the global result. - void EmitFullWarpShuffleDownLoopForAllReduces( - const absl::InlinedVector& reducers, - const absl::InlinedVector& - partial_result_addresses); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 24f07e68973a5b374976bf2a08f63697368cad50..bd53b90b42d8e657a3ee58e7ca03fb60522aae28 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -199,8 +199,7 @@ std::unique_ptr GetTargetMachine( } return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, - Optional(RelocModel), Optional(CMModel), - codegen_opt_level)); + getRelocModel(), getCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization @@ -394,8 +393,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, int32 opt_level = hlo_module_config.debug_options().xla_backend_optimization_level(); - CHECK_GE(opt_level, 2) - << "The XLA GPU backend doesn't support unoptimized code generation"; + if (opt_level < 2) { + LOG(ERROR) << std::string(80, '*'); + LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code " + "generation but "; + LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level + << "!"; + LOG(ERROR) << "(Supported configuration is " + "--xla_backend_optimization_level >= 2.)"; + LOG(ERROR) << std::string(80, '*'); + } AddOptimizationPasses(opt_level, /*size_level=*/0, target_machine.get(), &module_passes, diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index e934cbda1765cb10b4ff2ac14c3ff2f7a5f5cc41..f3e17d888242a36c268dcbfa0d6530f80cedceb0 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -479,7 +479,8 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor) { + int cc_minor, + bool disable_ptx_optimizations) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); @@ -519,6 +520,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } + if (disable_ptx_optimizations) { + ptxas_args.push_back("-O0"); + } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); @@ -739,8 +743,9 @@ StatusOr> NVPTXCompiler::RunBackend( } } - const std::vector cubin = - CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); + const std::vector cubin = CompilePtxOrGetCachedResult( + ptx, cc_major, cc_minor, + module->config().debug_options().xla_gpu_disable_ptxas_optimizations()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -772,9 +777,9 @@ StatusOr> NVPTXCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + bool disable_ptx_optimizations) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -802,8 +807,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = - CompilePtx(*cache_ptx, cc_major, cc_minor); + StatusOr> maybe_cubin = CompilePtx( + *cache_ptx, cc_major, cc_minor, disable_ptx_optimizations); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index f79ae2990ae7d6e6985b15727a72358289121aa9..be5e31a50112686841e6f18b76f382a56e61bafc 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -97,8 +97,9 @@ class NVPTXCompiler : public LLVMCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. - std::vector CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, int cc_minor); + std::vector CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + bool disable_ptx_optimizations); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 8b50cfa9aed90091cfbedc1df902440ec9bf2a80..0361c87428f6e4c031d95492a5bc782ad388e5b5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -20,19 +20,19 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = match; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -261,7 +261,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); - EXPECT_THAT(copy, op::Copy(constant)); + EXPECT_THAT(copy, GmockMatch(m::Copy(m::Op().Is(constant)))); } TEST_F(HloComputationTest, DeepCopyTuple) { @@ -278,8 +278,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); - EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(tuple_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); } @@ -297,7 +298,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { ShapeTree indices_to_copy(constant->shape(), /*init_value=*/true); EXPECT_THAT(computation->DeepCopyInstruction(constant, &indices_to_copy) .ValueOrDie(), - op::Copy(constant)); + GmockMatch(m::Copy(m::Op().Is(constant)))); } { @@ -330,10 +331,11 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); - EXPECT_THAT(deep_copy, op::Tuple(copies_added.element({0}), - copies_added.element({1}))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({0})), + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({1}))))); } { @@ -346,8 +348,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::GetTupleElement(tuple), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, + GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) == nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -363,8 +366,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) != nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -381,7 +385,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); // No copy should be added. - EXPECT_THAT(copy, op::AfterAll()); + EXPECT_THAT(copy, GmockMatch(m::AfterAll())); } TEST_F(HloComputationTest, DeepCopyTokenTuple) { @@ -399,8 +403,9 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { // Only the array (second tuple element) should be copied. The token is passed // through transparently. - EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(copy, GmockMatch(m::Tuple( + m::GetTupleElement(m::Op().Is(tuple)), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); } TEST_F(HloComputationTest, CycleDetection) { @@ -443,13 +448,15 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 312b5d020c398feb7738d14a9cfa0928d5178948..51177f24f5ee702be96fc8b4530ed38a5798109f 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -113,7 +113,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module, } const string message = - StrCat("after ", after_pass_name, ", before ", before_pass_name); + absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; VLOG(3) << module.entry_computation_layout().ToString(); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 2297edcbe1d167f0752423f76b795b3592e85c47..7559ed1bab84b21a4d51bc38db999900befcfad7 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -457,8 +457,13 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = HloReachabilityMap::Build(computation_); - HloInstructionSet do_not_duplicate = - ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + HloInstructionSet do_not_duplicate; + // If we allow duplications, we need to compute which instructions we do not + // want to duplicate based on a global analysis of the graph. + if (may_duplicate_) { + do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + } auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph @@ -566,8 +571,8 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { - // A consumer operand may have been multii-output fused into a parallel - // consumer and thus be missing from the oridinal reachability map. + // A consumer operand may have been multi-output fused into a parallel + // consumer and thus be missing from the original reachability map. if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { reachability_ = HloReachabilityMap::Build(consumer->parent()); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 6b483126499fe1e635a7d13cf597ec5d089c5b24..58b7135cea7419f13d60ed510ecf7a88126aee48 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -394,6 +394,56 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. + EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + +TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. + EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { auto module = ParseHloString(R"( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 7635fbfed6f6a51fc9d203251d9bebf43cc63fd9..de9204011ce5ba8a9fc2871c6bd7120b6ed371b5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -85,6 +85,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); + evaluator_->ResetVisitStates(); TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( *computation, arg_literals)); } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 311bd7890545b5b2cbec920d2d12ddd482d0d53c..5c661bfacb08fe27f3cbdc1fb9db083315166008 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index 1aa85eb8d2d206bf0537deb659e779b24fffbb0a..c26711e526c9b89cdedcb6aed9f93d41dd25dc83 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -120,7 +120,7 @@ KernelMappingScheme::KernelMappingScheme( absl::Span req_block_sizes, int64 num_threads_y, int64 num_threads_x, llvm::IRBuilder<>* b) : b_(b), - dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), + dims_in_elems_(dims_in_elems), tile_sizes_{1, tile_size_y, tile_size_x}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 7277aeac8ad2086a2f6419b1fdb60c4872841adc..06002d57b0d7daa07f903feebe67a60a083c0e7c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -90,16 +90,15 @@ class KernelMappingScheme { enum { DimZ = 0, DimY, DimX, DimTot }; public: - KernelMappingScheme() {} // dims_in_elems: the normalized tensor dimensions. // req_block_sizes: the requested block size in number of tiles for each // dimension. The actual block size is set to min(req_block_size, // dims_in_number_of_blocks). - KernelMappingScheme(absl::Span dims_in_elems, int64 tile_size_y, - int64 tile_size_x, - absl::Span req_block_sizes, - int64 num_threads_y, int64 num_threads_x, - llvm::IRBuilder<>* b); + explicit KernelMappingScheme(absl::Span dims_in_elems, + int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); absl::Span GetDimensionsInElements() const { return dims_in_elems_; @@ -134,10 +133,6 @@ class KernelMappingScheme { } absl::Span GetBlockSizes() const { return block_sizes_; } - int64 GetTileBlockSizeForDimension(int d) const { - DCHECK(d >= DimZ && d <= DimX); - return dims_in_blocks_[d]; - } int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } @@ -168,7 +163,7 @@ class KernelMappingScheme { private: llvm::IRBuilder<>* b_; // The number of elements in each dimension. - std::vector dims_in_elems_; + absl::Span dims_in_elems_; // The number of elements for each dimension of a tile. std::vector tile_sizes_; diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index cc06cf6ec76930da8e326a0744e085f39c21b5f2..c35f72699bfe90f7b8021916c0f81d5e1926ff4c 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -64,6 +64,9 @@ namespace xla { // e.g. IsConstantScalar() or IsConstantScalar(42). // - WithFusionKind // - WithTupleIndex: get-tuple-element operations with the given tuple index +// - WithOneUse: Instruction is used as an operand exactly once. +// - WithOneUser: Instruction is used by exactly one other instruction, but +// is possibly used more than once as an operand (e.g. multiply(x,x)). // // Shape(): // - EqualTo @@ -1133,6 +1136,13 @@ inline const HloInstruction* HloOperand(const HloInstruction* instr, return instr->operand(idx); } +// Pretty-printer for HloInstruction. Sort of like ToShortString, but with +// fewer %s and more shapes. +inline string InstToString(const HloInstruction* inst) { + return inst->ToString( + HloPrintOptions().set_print_metadata(false).set_print_percent(false)); +} + template class HloInstructionPattern; @@ -1187,14 +1197,14 @@ class HloInstructionIsImpl { bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { if (inst != inst_) { EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" - << inst_->ToShortString() << ")"; + << InstToString(inst_) << ")"; return false; } return true; } void DescribeTo(std::ostream* os, int64 indent = 0) const { - *os << "which is " << inst_ << " (" << inst_->ToShortString() << ")"; + *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; } private: @@ -1603,6 +1613,64 @@ class HloInstructionPatternParameterNumImpl { int64 parameter_num_; }; +// Superclass that contains common code used by Op::WithOneUse() and +// Op::WithOneUser(). +class HloInstructionPatternOneUseOrUserImpl { + protected: + bool MatchOneUser(const HloInstruction* inst, MatchOption option) const { + if (inst->user_count() != 1) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected exactly one."; + if (inst->user_count() > 1) { + EXPLAIN << "\nAll users:"; + for (const HloInstruction* user : inst->users()) { + EXPLAIN << "\n - " << InstToString(user); + } + } + return false; + } + return true; + } +}; + +class HloInstructionPatternOneUseImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + if (!MatchOneUser(inst, option)) { + return false; + } + + int64 use_count = absl::c_count_if( + inst->users()[0]->operands(), + [&](const HloInstruction* operand) { return operand == inst; }); + if (use_count != 1) { + EXPLAIN << "HloInstruction is used " << use_count + << " times by its user, but is expected to be used just once: " + << InstToString(inst->users()[0]); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one use"; + } +}; + +class HloInstructionPatternOneUserImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchOneUser(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one user (but possibly is used multiple times by " + "that instruction)"; + } +}; + // Matches a constant scalar or effective scalar, optionally with a given value. template class HloConstantScalarImpl { @@ -1706,10 +1774,7 @@ class HloInstructionPattern { return true; } if (inst != nullptr) { - EXPLAIN << "\nin " - << inst->ToString(HloPrintOptions() - .set_print_metadata(false) - .set_print_percent(false)); + EXPLAIN << "\nin " << InstToString(inst); } return false; } @@ -1722,10 +1787,7 @@ class HloInstructionPattern { } return true; } - EXPLAIN << "\nin " - << inst->ToString(HloPrintOptions() - .set_print_metadata(false) - .set_print_percent(false)); + EXPLAIN << "\nin " << InstToString(inst); return false; } @@ -1877,6 +1939,22 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } + // Modifies the pattern to match if the instruction is used exactly once. + // Does not match if the instruction is used twice by the same user (e.g. + // multiply(x,x)). + constexpr auto WithOneUse() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + return AppendImpl(HloInstructionPatternOneUseImpl()); + } + + // Modifies the pattern to match if the instruction is used by exactly one + // other instruction. Will match if the instruction is used twice, so long as + // it's by the same user (e.g. multiply(x,x)). + constexpr auto WithOneUser() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + return AppendImpl(HloInstructionPatternOneUserImpl()); + } + void DescribeTo(std::ostream* os, int64 indent = 0) const { impl_.DescribeTo(os, indent); } @@ -2154,6 +2232,7 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, // We could implement all ops as "variadic" ops, but it would make the // already-bad compile errors even worse. +XLA_VARIADIC_OP_PATTERN(AfterAll); XLA_VARIADIC_OP_PATTERN(Concatenate); XLA_VARIADIC_OP_PATTERN(CustomCall); XLA_VARIADIC_OP_PATTERN(Map) diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 13886fa6f5b7b55283e6e420734a22312987d8a6..186ef0c7911a2724df810780e018f52586e3e6a8 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -767,10 +767,11 @@ TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { "in c = f64[] constant(2.25)"); EXPECT_DESC_AND_EXPLANATION( constant, m::Op().Is(iota.get()), - absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), " (", - iota->ToShortString(), ")"), + absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)"), absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x", - absl::Hex(iota.get()), " (", iota->ToShortString(), ")\n", + absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)\n" "in c = s32[] constant(0)")); } @@ -875,5 +876,60 @@ TEST(PatternMatcherTest, Parameter) { "in p0 = f32[] parameter(0)"); } +TEST(PatternMatcherTest, OneUseAndOneUser) { + auto param = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUse(), + "an HloInstruction which has exactly one use", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUser(), + "an HloInstruction which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + { + auto reshape = + SetName("r", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + + auto reshape1 = + SetName("r1", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + + const char* kMultipleUserExplanation = + "HloInstruction has 2 users, but expected exactly one.\n" + "All users:\n" + " - r = f32[1]{0} reshape(f32[] p0)\n" + " - r1 = f32[1]{0} reshape(f32[] p0)\n" + "in p0 = f32[] parameter(0)"; + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + kMultipleUserExplanation); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()), + kMultipleUserExplanation); + } + + auto add = SetName("add", HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, + param.get(), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + "HloInstruction is used 2 times by its user, but is expected to be " + "used just once: add = f32[] add(f32[] p0, f32[] p0)\n" + "in p0 = f32[] parameter(0)"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index f3cc51ca9158d5c355c656b5450da1a66d96a379..a4d4e1e53e727bdf7822cacaa4559fcae59d4eae 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -584,7 +584,7 @@ namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = StripLeadingAsciiWhitespace(*s); + *s = absl::StripLeadingAsciiWhitespace(*s); if (absl::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; @@ -597,7 +597,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = StripLeadingAsciiWhitespace(*s); + *s = absl::StripLeadingAsciiWhitespace(*s); must_end = !absl::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 48bbc4e47285988935f3a6e1444a4c7918aaab7b..5a7a4faa7e89b27fb537f20d94c21cb4a76e000d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -303,11 +303,6 @@ xla_test( name = "conv_depthwise_test", timeout = "long", srcs = ["conv_depthwise_test.cc"], - blacklisted_backends = [ - # disabled because of a break b/119590850. - "cpu", - "gpu", - ], shard_count = 50, deps = [ "//tensorflow/compiler/xla:execution_options_util", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0615f9425c1289d666641f4d581946b44b4895ce..915b456b52215f8d6a9eb6c5b933f3502f1d3d2c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -329,13 +329,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); - auto b_param = ConstantR1(&builder, b_values); + auto b_param = Parameter(&builder, 1, a_literal.shape(), "b_param"); + auto b_constant = ConstantR1(&builder, b_values); - auto sum1 = Add(a_constant, b_constant); - auto sum2 = Add(a_constant, b_param); - auto sum3 = Add(a_param, b_constant); - auto sum4 = Add(a_param, b_param); + auto sum1 = Add(a_constant, b_param); + auto sum2 = Add(a_constant, b_constant); + auto sum3 = Add(a_param, b_param); + auto sum4 = Add(a_param, b_constant); auto sum = Add(sum1, sum2); sum = Add(sum, sum3); @@ -350,6 +350,44 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } +// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. +XLA_TEST_F(ArrayElementwiseOpTest, + DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { + XlaBuilder builder(TestName()); + std::vector values(30, 0.0); + auto a_literal = LiteralUtil::CreateR1(values); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b_literal = LiteralUtil::CreateR1(values); + auto b = Parameter(&builder, 1, b_literal.shape(), "x"); + + // Construct a sequence of diamond-shaped gadgets like this: + // + // add + // / \ + // slice slice + // \ / + // add + // + // Each 'left' slice removes the last element, each 'right' slice removes the + // first element. In this way, we index into the add with different + // multi-dimensional index arrays, which defeats the caching we use to avoid + // exponential compile time. + std::function generate_recursive = + [&](int64 slice_size) -> XlaOp { + if (slice_size == values.size()) { + return Add(a, b); + } + XlaOp param = generate_recursive(slice_size + 1); + auto slice1 = Slice(param, {0}, {slice_size}, {1}); + auto slice2 = Slice(param, {1}, {slice_size + 1}, {1}); + return Add(slice1, slice2); + }; + generate_recursive(1); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, {0.0}, {a_data.get(), b_data.get()}); +} + XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc index 60ce576ceb20b89b59e72d821e63b0ccdee51b0b..627a17a0ca114085240dbaf28211bb3511cf0cab 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -50,9 +50,9 @@ class DepthwiseConvolution2DTest static std::vector GetConv2DTestCases() { std::vector config_set; std::vector> config_options = { - {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, - {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {128, 1, 2, 144}, - {256, 1, 2, 64}, {64, 14, 12, 172}, {16, 9, 4, 16}}; + {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, + {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {64, 14, 12, 172}, + {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}}; for (auto option : config_options) { int64 feature = option[0]; @@ -136,7 +136,7 @@ string BuildHloTextDepthwiseConvolution2D( if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -161,7 +161,7 @@ string BuildHloTextDepthwiseConvolution2D( } else if (spec.stride == -1) { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -185,7 +185,7 @@ string BuildHloTextDepthwiseConvolution2D( } else { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -215,13 +215,13 @@ XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { const string hlo_text = BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); - EXPECT_TRUE(RunAndCompareNoHloPasses( - hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer despecializer; - return despecializer.Run(module).status(); - })); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 25091b8d5d5498edf3ce86efe225cd0e2fd8ff6b..c5d8b663f4abe77e05ec213d2e4e075c260a8655 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index bdeb1728fa2321f25d9db230f2d449a7b4b348ee..a37eac7fe441d91aa71e1b6fd7b84099fee2215b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -213,6 +213,9 @@ message DebugOptions { // the host that run models in parallel across multiple devices. int32 xla_force_host_platform_device_count = 102; + // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + bool xla_gpu_disable_ptxas_optimizations = 103; + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 8c6191ddc06ea7d85f5fd21a7d4058c669ffdeb2..751329eefc33f3372335c805233dafabbf42bf36 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -228,14 +228,35 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); - - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tensor->scalar()() = key; - + if (config_proto.return_exploded_tuple() && + xla::ShapeUtil::IsTuple(output_tuple->on_device_shape())) { + int64 tuple_element_count = + xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({tuple_element_count}), &output_tensor)); + + for (int64 i = 0; i < tuple_element_count; ++i) { + xla::ShapeIndex shape_index; + shape_index.push_back(i); + + XRTTupleAllocation* suballocation; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + output_tuple, shape_index, &suballocation, + /*alias_parent_allocation=*/false)); + int64 key; + TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); + output_tensor->vec()(i) = key; + } + output_tuple->Unref(); + } else { + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({}), &output_tensor)); + int64 key; + TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); + output_tensor->scalar()() = key; + } return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index ffea592491d43788b876a51866dc8a6611e8c734..3258286c10665225aab917107ffa614459c53f3d 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_GPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_CPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); + REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .Device(DEVICE_XLA_GPU) .HostMemory("handle") diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 54b06558adcd8ef1f8f1bee52d210d558801afea..26a58fa42d8b730b365b11d2e5608e9945497763 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that writes a new literal value into device-resident memory. +template +class XRTWriteLiteralOp : public OpKernel { + public: + explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~XRTWriteLiteralOp() override = default; + XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete; + XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTWriteLiteralOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + const Tensor& literal_info = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()), + errors::Internal("literal input should be a string scalar")); + xla::LiteralProto literal_proto; + OP_REQUIRES(ctx, + literal_proto.ParseFromString(literal_info.scalar()()), + errors::InvalidArgument( + "Unable to parse allocation input to LiteralProto")); + xla::Literal literal; + OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal)); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + typename DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, + allocation->WriteLiteral(device_ref.backend(), literal)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = allocation_handle; + ctx->set_output(0, output); + } +}; + // Op that discards a handle to device memory. template class XRTReleaseAllocationOp : public OpKernel { diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 07d025ce343f229097b557d33ad41bf9612b0696..a3d63106fa14674a9f5887ccfd908ce17dbc6384 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTWriteLiteral") + .Input("handle: int64") + .Input("literal: string") + .Output("output_handle: int64") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Copies the input literal into the device memory pointed to by handle. +Returns the handle itself. + +'handle' is the id returned from the Op that produced the on-device allocation. +'literal' is a serialized xla::LiteralProto proto to be written to device memory. +)"); + REGISTER_OP("XRTReadLiteralAndRelease") .Input("handle: int64") .Output("literal: string") diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index b9262c1843a7ae48af49acbef5ba4ef58ec0f050..abaa17e50e3f5e47a45f5a8a45fa2090d3efee39 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a, auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match " << a.DebugString() + LOG(INFO) << "LiteralProtos don't match: " << a.DebugString() << " != " << b.DebugString(); } return equal; @@ -175,6 +175,18 @@ xla::XlaComputation AddAndTuple() { return builder.Build().ValueOrDie(); } +xla::XlaComputation AddAndSubTuple() { + xla::XlaBuilder builder("AddAndSubTuple"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P0"); + auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P1"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {sum, sub}); + return builder.Build().ValueOrDie(); +} + void StoreComputationSnapshot(const xla::XlaComputation& computation, xla::HloSnapshot* dst) { auto snapshot = computation.Snapshot().ValueOrDie(); @@ -203,6 +215,56 @@ xla::ProgramShape XlaCompiledProgramShape( ->ComputeProgramShape(); } +TEST(RawApiTest, AllocAndRewrite) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + auto read_back = ops::XRTReadLiteral(root, handle); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle = outputs[1].scalar()(); + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); + outputs.clear(); + + xla::LiteralProto new_literal = + xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); + auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), + new_literal.SerializeAsString()); + auto write_op = + ops::XRTWriteLiteral(root, Input(allocation_handle), new_value); + TF_ASSERT_OK(root.status()); + TF_EXPECT_OK(session.Run({write_op}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(allocation_handle, outputs[0].scalar()()); + outputs.clear(); + + auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto new_response; + EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); + + auto release = + ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); @@ -681,6 +743,70 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + *p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); + + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), + xla::ShapeUtil::MakeShape(xla::F32, {})}) + .ToProto(); + StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({result}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + auto handles_vec = outputs.front().vec(); + EXPECT_EQ(handles_vec.size(), 2); + + const float kResults[2] = {15.0f, 9.0f}; + for (int64 i = 0; i < handles_vec.size(); ++i) { + auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i))); + std::vector voutputs; + TF_EXPECT_OK(session.Run({read_back}, &voutputs)); + EXPECT_EQ(voutputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(kResults[i]); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + } +} + TEST(RawApiTest, LeakCompilationReference) { xrt::XLAComputation c; auto config = c.mutable_config(); diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index e149f2f43593ea412ef279b2c99dabac285cdac4..378bb9246f27b8106310d565435404d7ac260a87 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -101,4 +101,8 @@ message XRTExecutionConfig { bool release_input_handles = 5; // If true, release the handle to the computation after running. bool release_compilation_handle = 6; + // If set to true, and the result shape is a tuple, then instead of returning + // a single tuple allocation the execution will return a vector of + // allocations, one for each of the first-level elements of the result tuple. + bool return_exploded_tuple = 7; } diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 3a99820d7aa9e9546cc95385fd98c05f28988e9e..31603e044d17baa3ae0ae583f61837811bb12495 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" #include +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -41,6 +43,34 @@ namespace tensorflow { namespace { +class BufferAllocStats { + public: + struct Stats { + int64 count = 0; + int64 size = 0; + }; + + Stats ReportAlloc(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count += 1; + device_stats->size += msize; + return *device_stats; + } + + Stats ReportFree(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count -= 1; + device_stats->size -= msize; + return *device_stats; + } + + private: + mutable mutex lock_; + std::map stats_; +}; + const char* kTupleContainer = "tuples"; int64 get_uid() { @@ -48,6 +78,11 @@ int64 get_uid() { return static_cast(unsigned_rand); } +BufferAllocStats* GetAllocStats() { + static BufferAllocStats* stats = new BufferAllocStats(); + return stats; +} + Status AllocateScopedShapedBuffer( xla::Backend* backend, int device_ordinal, const xla::Shape& shape, std::unique_ptr* buffer) { @@ -100,9 +135,19 @@ XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, xla::DeviceMemoryAllocator* allocator) : allocation_(allocation), device_ordinal_(device_ordinal), - allocator_(allocator) {} + allocator_(allocator) { + if (VLOG_IS_ON(2)) { + auto stats = + GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size()); + LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_ + << " count=" << stats.count << " size=" << stats.size; + } +} XRTBufferAllocation::~XRTBufferAllocation() { + if (VLOG_IS_ON(2)) { + GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); + } // Deallocate explicitly allows allocation_ to be null. Status s = allocator_->Deallocate(device_ordinal_, allocation_); // Nothing to do but check fail here if memory datastructures are corrupted. @@ -183,6 +228,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, return Status::OK(); } +Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, + const xla::Literal& literal) { + if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) { + return errors::InvalidArgument( + "New literal shape not matching the existing one: literal=", + xla::ShapeUtil::HumanStringWithLayout(literal.shape()), + " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); + } + auto transfer_manager = backend->transfer_manager(); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + return transfer_manager->TransferLiteralToDevice(stream.get(), literal, + ToShapedBuffer()); +} + void XRTTupleAllocation::DiscardAllocation( const xla::ShapeIndex& buffer_index) { buffers_.element(buffer_index)->DiscardAllocation(); diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 73b5584e38f781343fe6793af7ad28232fbfc184..3664c0cd4e6ad26945ae1012208fdb006164a066 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase { Status ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal); + // Write a new literal value to the allocation. + Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); + // True if none of the buffers in the allocation are aliased by any other live // handle. bool IsExclusiveOwner(); diff --git a/tensorflow/contrib/android/cmake/build.gradle b/tensorflow/contrib/android/cmake/build.gradle index 17a57b99fd6c9efc09bda0ce1249b1f51bd5af5c..ddec08894f34f96b080610f1d27a6a436f7ffa91 100644 --- a/tensorflow/contrib/android/cmake/build.gradle +++ b/tensorflow/contrib/android/cmake/build.gradle @@ -22,8 +22,8 @@ android { } externalNativeBuild { cmake { - arguments '-DANDROID_TOOLCHAIN=gcc', - '-DANDROID_STL=gnustl_static' + arguments '-DANDROID_TOOLCHAIN=clang', + '-DANDROID_STL=c++_static' } } } @@ -70,7 +70,7 @@ if (ndkDir == null || ndkDir == "") { ndkDir = System.getenv('ANDROID_NDK_HOME') } -if(! Os.isFamily(Os.FAMILY_WINDOWS)) { +if (!Os.isFamily(Os.FAMILY_WINDOWS)) { // This script is for non-Windows OS. For Windows OS, MANUALLY build // (or copy the built) libs/headers to the // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index 2c44abed5e1955cc666273e97e6b2378766f13d2..79052bee35c7895cb4048b10c1f73acb036d1587 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -51,25 +51,18 @@ BIGTABLE_TABLE_NAME = '' PREFIX = 'train-' def main(): + tf.enable_eager_execution() + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) table = client.table(BIGTABLE_TABLE_NAME) dataset = table.keys_by_prefix_dataset(PREFIX) - iterator = dataset.make_initializable_iterator() - get_next_op = iterator.get_next() - with tf.Session() as sess: - print('Initializing the iterator.') - sess.run(iterator.initializer) - print('Retrieving rows:') - row_index = 0 - while True: - try: - row_key = sess.run(get_next_op) - print('Row key %d: %s' % (row_index, row_key)) - row_index += 1 - except tf.errors.OutOfRangeError: - print('Finished reading data!') - break + print('Retrieving rows:') + row_index = 0 + for row_key in dataset: + print('Row key %d: %s' % (row_index, row_key)) + row_index += 1 + print('Finished reading data!') if __name__ == '__main__': main() diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py index 316da9ebe152ef52c7e7f846cf8c3eb1555ee8a6..197f5578eb010bee5a3aad7c05446393193f99e2 100644 --- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -57,7 +57,7 @@ class BigtableOpsTest(test.TestCase): sess.run(write_op) def runReadKeyTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected = list(self.COMMON_ROW_KEYS) expected.reverse() @@ -78,7 +78,7 @@ class BigtableOpsTest(test.TestCase): self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) def runScanTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_keys.reverse() @@ -120,7 +120,7 @@ class BigtableOpsTest(test.TestCase): def testLookup(self): ds = self._table.keys_by_prefix_dataset("r") ds = ds.apply(self._table.lookup_columns(cf1="c1")) - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_values = list(self.COMMON_VALUES) @@ -141,7 +141,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeys(self): ds = self._table.sample_keys() - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_key = self.COMMON_ROW_KEYS[0] with self.cached_session() as sess: @@ -161,7 +161,7 @@ class BigtableOpsTest(test.TestCase): sess.run(n) def runSampleKeyPairsTest(self, ds, expected_key_pairs): - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -218,7 +218,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndStartKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="r1", end="") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -226,14 +226,14 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndEndKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="", end="r3") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) def testParallelScanPrefix(self): ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -251,7 +251,7 @@ class BigtableOpsTest(test.TestCase): def testParallelScanRange(self): ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 9f97934193de2a110cbbb8ec88b83fbbea3b5b73..b6cdc7aab0320fe5f457288ada03a46e18a694cc 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -35,8 +35,8 @@ from tensorflow.contrib.util import loader from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader @@ -111,8 +111,7 @@ class BigtableClient(object): class BigtableTable(object): - """BigtableTable is the entrypoint for reading and writing data in Cloud - Bigtable. + """Entry point for reading and writing data in Cloud Bigtable. This BigtableTable class is the Python representation of the Cloud Bigtable table within TensorFlow. Methods on this class allow data to be read from and @@ -593,16 +592,8 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): self._table = table @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.TensorShape([]) - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) class _BigtablePrefixKeyDataset(_BigtableKeyDataset): @@ -662,16 +653,9 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): # pylint: disable=protected-access @@ -697,16 +681,9 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._num_outputs = len(normalized) + 1 # 1 for row key @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): return gen_bigtable_ops.bigtable_scan_dataset( @@ -730,16 +707,10 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._end = end @property - def output_classes(self): - return (ops.Tensor, ops.Tensor) - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - - @property - def output_types(self): - return (dtypes.string, dtypes.string) + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) def _as_variant_tensor(self): # pylint: disable=protected-access diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 8b76f37858291eb6f44939356ed20a26ed371f26..46a193971c5084523d432065f265fa7a9909f595 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -31,17 +31,17 @@ if (systemlib_ABSEIL_CPP) message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") - add_custom_target(abseil_cpp_build) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + add_custom_target(abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) else (systemlib_ABSEIL_CPP) include (ExternalProject) - set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp) set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) - set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp-build) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -77,15 +77,12 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() - ExternalProject_Add(abseil_cpp_build + ExternalProject_Add(abseil_cpp PREFIX abseil_cpp URL ${abseil_cpp_URL} URL_HASH ${abseil_cpp_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release - COMMAND ${CMAKE_COMMAND} --build . --config Release INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} @@ -96,6 +93,6 @@ else (systemlib_ABSEIL_CPP) include_directories(${abseil_cpp_INCLUDE_DIR}) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index eca4f3c8c8866ff60c4ee8332a2baaa972fe3b83..e570c09ecb5e64130ed6f3375a51d74850cc3989 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) +set(GRPC_TAG 69b6c047bc767b4d80e7af4d00ccb7c45b683dae) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 48dbfb92e65b0ed456846f83ddd5eed4d74dfe67..62005dd113bfb80fbdf23afb6d4aa5f90a1e32de 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -213,6 +213,10 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# absl directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/abseil_cpp/src/abseil_cpp/absl/ + DESTINATION include/absl + FILES_MATCHING PATTERN "*.h") # mkl if (tensorflow_ENABLE_MKL_SUPPORT) install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 0456463a1928cf226010670b90a5d574579e0411..6c5f8c6b00975b3fba041271309a93cecd9f5057 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -46,7 +46,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -88,7 +88,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -115,9 +115,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -142,7 +141,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): tensor_shape.TensorShape((3, 4))) self.assertEqual(actual_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -184,7 +183,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -211,9 +210,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index d2a72272db159755ac2d741bcdbce9ec646d928e..b9840b1ff1a3df5a05db0e64f436637220f49f80 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -23,6 +23,7 @@ import shutil from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -48,7 +49,7 @@ class LMDBDatasetTest(test_base.DatasetTestBase): num_repeats = 2 dataset = readers.LMDBDataset(filenames).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index c5a786232252432481566e3cde23e9310df172cc..2527706709fae8e459aca3489324d4db3c784be6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -63,13 +63,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> # _SlideDataset(window_size, window_shift, window_stride). - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -127,13 +127,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, stride=stride_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -173,12 +173,12 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer with self.cached_session() as sess: @@ -204,9 +204,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -233,9 +233,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): values=array_ops.fill([math_ops.to_int32(i)], i), dense_shape=[i]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -265,11 +265,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(_sparse).apply( sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) init_op = iterator.initializer get_next = iterator.get_next() @@ -305,11 +304,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_generator( generator, dtypes.float32, output_shapes=[None]).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) next_element = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index aa42782807ab7448105da617c619ff09a50b2680..c0152156a1ba70297adb7054622b15ca04f859cd 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -21,10 +21,9 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util import deprecation @@ -396,18 +395,10 @@ class LMDBDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): return gen_experimental_dataset_ops.experimental_lmdb_dataset( - self._filenames, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + self._filenames, **dataset_ops.flat_structure(self)) @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - - @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 9ebdca317f2dd34d33e447c1ddd2368b91ee9be3..5c6ee6bfdc7167d14b292f8f763adafca4e3a72c 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -39,11 +39,10 @@ class _SlideDataset(dataset_ops.UnaryDataset): self._window_shift = ops.convert_to_tensor( window_shift, dtype=dtypes.int64, name="window_shift") - # pylint: disable=protected-access - input_structure = structure.Structure._from_legacy_structure( + input_structure = structure.convert_legacy_structure( input_dataset.output_types, input_dataset.output_shapes, input_dataset.output_classes) - self._output_structure = input_structure._batch(None) + self._structure = input_structure._batch(None) # pylint: disable=protected-access def _as_variant_tensor(self): return ged_ops.experimental_sliding_window_dataset( @@ -51,19 +50,11 @@ class _SlideDataset(dataset_ops.UnaryDataset): window_size=self._window_size, window_shift=self._window_shift, window_stride=self._window_stride, - **dataset_ops.flat_structure(structure=self._output_structure)) + **dataset_ops.flat_structure(self)) @property - def output_classes(self): - return self._output_structure._to_legacy_output_classes() # pylint: disable=protected-access - - @property - def output_shapes(self): - return self._output_structure._to_legacy_output_shapes() # pylint: disable=protected-access - - @property - def output_types(self): - return self._output_structure._to_legacy_output_types() # pylint: disable=protected-access + def _element_structure(self): + return self._structure @deprecation.deprecated_args( diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index e988b63a28718e509df0d5ce42423ba4616b0e60..5c50a20490482856becedf7b1379d2a0583d9a11 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -77,11 +77,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._num_workers = 1 if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = ["/device:CPU:0"] + local_devices = ("/device:CPU:0",) self._worker_device = device_util.canonicalize("/device:CPU:0") self._collective_keys = cross_device_utils.CollectiveKeys() @@ -104,7 +104,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): if task_type is None or task_id is None: raise ValueError("When `cluster_spec` is given, you must also specify " "`task_type` and `task_id`") - if task_type not in ["chief", "worker"]: + if task_type not in ("chief", "worker"): raise ValueError( "Unrecognized task_type: %r, valid task types are: \"chief\", " "\"worker\"." % task_type) @@ -119,12 +119,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._worker_device = "/job:%s/task:%d" % (task_type, task_id) if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = [self._worker_device] + local_devices = (self._worker_device,) self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 00a220880bf88a8ed931a48c9dc6adf6bb0c618f..683cc89bfbae9c877ea6794d311ffc00c96c6937 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -27,7 +27,6 @@ from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import values -from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib @@ -166,7 +165,9 @@ def get_multi_inputs_multi_outputs_data(): return (train_data, test_data) -def batch_wrapper(dataset, batch_size, distribution): +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined shapes. if isinstance(distribution, tpu_strategy.TPUStrategy): @@ -213,9 +214,11 @@ def multi_input_output_model(): return model -def get_correctness_test_inputs(use_numpy, with_distribution, +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict): """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 global_batch_size = 64 batch_size = global_batch_size # TODO(b/118776054): Use global batch size for Keras/DS support. @@ -231,14 +234,19 @@ def get_correctness_test_inputs(use_numpy, with_distribution, 'batch_size': batch_size, 'x': x_train, 'y': y_train, - 'epochs': 1, + 'epochs': training_epochs, 'shuffle': False, } - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } predict_inputs = { 'x': np.array(x_predict, dtype=np.float32), } @@ -247,22 +255,32 @@ def get_correctness_test_inputs(use_numpy, with_distribution, # keras.fit/evaluate/predict. The batch size is part of the dataset. train_dataset = dataset_ops.Dataset.from_tensor_slices( (x_train, y_train)) - x = batch_wrapper(train_dataset, batch_size, with_distribution) + x = batch_wrapper( + train_dataset, batch_size, with_distribution, repeat=training_epochs) training_inputs = { 'batch_size': None, 'x': x, 'y': None, - 'epochs': 1, + 'epochs': training_epochs, 'shuffle': False, 'steps_per_epoch': len(x_train) // global_batch_size, } - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + predict_batch_size = len(x_predict) if use_per_core_batch_size: predict_batch_size //= with_distribution.num_replicas_in_sync @@ -321,11 +339,17 @@ def strategy_and_input_combinations(): return ( combinations.times( combinations.combine(distribution=strategies_minus_tpu), - combinations.combine(mode=['graph'], use_numpy=[True, False]) - + combinations.combine(mode=['eager'], use_numpy=[False])) + + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + combinations.times( combinations.combine(distribution=tpu_strategies), - combinations.combine(mode=['graph'], use_numpy=[True, False]))) + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) def strategy_for_numpy_input_combinations(): @@ -1061,8 +1085,8 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) def test_unsupported_features(self, distribution): with self.cached_session(): @@ -1110,8 +1134,8 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): @@ -1137,12 +1161,6 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.ReduceLROnPlateau()]) - with self.assertRaisesRegexp(ValueError, - 'histogram_freq in the TensorBoard callback ' - 'is not supported when using ' - 'DistributionStrategy.'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) class TestDistributionStrategyWithLossMasking(test.TestCase, @@ -1210,8 +1228,7 @@ class TestDistributionStrategyWithNormalizationLayer( class TestDistributionStrategyCorrectness(test.TestCase, parameterized.TestCase): - # TODO(b/120186218): Enable this for eager once metrics are working. - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_metric_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') @@ -1240,18 +1257,57 @@ class TestDistributionStrategyCorrectness(test.TestCase, train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0]) + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + + @combinations.generate(all_strategy_combinations()) + def test_eval_metrics_correctness(self, distribution): + with self.cached_session(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + distribute=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) @combinations.generate(strategy_and_input_combinations()) - def test_correctness(self, distribution, use_numpy): + def test_correctness(self, distribution, use_numpy, use_validation_data): + with self.cached_session(): - tolerance = 1e-5 + default_tolerance = 1e-5 + tol_table = {} if isinstance(distribution, (mirrored_strategy.MirroredStrategy, mirrored_strategy.CoreMirroredStrategy)): - # TODO(b/119257215): use the default one once the flakyness is fixed. - tolerance = 1e-4 + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } keras.backend.set_image_data_format('channels_last') np.random.seed(_RANDOM_SEED) @@ -1272,49 +1328,75 @@ class TestDistributionStrategyCorrectness(test.TestCase, # This is used to initialize the model for both the distribution and # non-distribution run. In addition, we add few non-linear layers to make # it non-trivial. - model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() initial_weights = model.get_weights() + del model # avoid accident usage. def fit_eval_and_predict(with_distribution=None): + model = _create_model() # We have initialized the model to the same weight for the distribution # and non-distribution run. model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse'], distribute=with_distribution) training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, with_distribution, + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict)) - model.fit(**training_inputs) - eval_result = model.evaluate(**eval_inputs) - weights = model.get_weights() - predict_result = model.predict(**predict_inputs) + result = {} + result['training_history_1'] = model.fit(**training_inputs).history - return weights, eval_result, predict_result + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) - wts_with_ds, eval_with_ds, predict_with_ds = fit_eval_and_predict( - with_distribution=distribution) - wts_without_ds, eval_without_ds, predict_without_ds = ( - fit_eval_and_predict(with_distribution=None)) + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) - # Verify that the weights, eval results, predict outputs are the same - # within some limits of tolerance. - self.assertAllClose( - wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. + continue + + tolerance = tol_table.get(key, default_tolerance) - # TODO(b/120186218): Enable this once metrics are working with eager. - if not context.executing_eagerly(): - self.assertAllClose( - eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) self.assertAllClose( - predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index dcc9df4cda51b87e95fb166a726170a8817715fc..f09483cb56b66fd4720ee71085203c14f1ccadc3 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -232,7 +232,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): fetches = distribution.unwrap( distribution.call_for_each_replica(model_fn, args=inputs)) if update_ops_in_cross_replica_mode: - fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return control_flow_ops.group(fetches) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -443,7 +443,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) - self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), loss_output=ctx.last_step_outputs["replica_loss_reduced"], diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 66512f983e1c80e0c7937d104cd4f73bfd934eb8..36be5c83f8bafb6c934d1d7682b5227b1f71c089 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -180,7 +180,7 @@ class MirroredStrategyVariableCreatorStackTest( variable_scope.variable_creator_scope(main_thread_creator): result = distribution.extended.call_for_each_replica(model_fn) result = distribution.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] + expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index c492d8bafc9024ed059f05b92e5466f3702726b9..8f13e9153ea7a951dd722c4549882c97e79b57fe 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -139,6 +139,27 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], var.eval()) + @combinations.generate(all_combinations) + def testAssignVariable(self, distribution): + + def replica_fn(): + var = variables.Variable([10.0, 11.0]) + # Here we expect to check the case when input value are variable. + val = variables.Variable([1., 2.]) + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + self.assertAllClose( + [10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. * (1 - 0.25)], + var.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index e322b6acb84c166a885c9aaa3002f331903a5063..fdbfba4e04358451a46b23ef250dc7c534c855a0 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -60,7 +60,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): if isinstance(colocate_with, six.string_types): with ops.device(colocate_with): return next_creator(*args, **kwargs) - if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and isinstance(colocate_with[0], six.string_types)): with ops.device(colocate_with[0]): return next_creator(*args, **kwargs) @@ -166,7 +166,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): return array_ops.identity(replica_local_var) def _unwrap(self, value): - return [value] + return (value,) def value_container(self, value): return value @@ -177,15 +177,15 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): @property def worker_devices(self): - return [self._device] + return (self._device,) @property def parameter_devices(self): - return [self._device] + return (self._device,) def non_slot_devices(self, var_list): del var_list - return [self._device] + return (self._device,) @property def experimental_should_init(self): @@ -216,4 +216,4 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): @property def devices(self): - return [self._distribution_strategy.extended.worker_devices[0]] + return self._distribution_strategy.extended.worker_devices diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index eaeb4d703015fc0762359b24dc23888c01e69111..2c7766f95fbcb7b68a53ad0052f21485c763a1db 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -145,14 +145,14 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. if num_gpus_per_worker > 0: - self._compute_devices = [ + self._compute_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - self._compute_devices = [self._worker_device] + self._compute_devices = (self._worker_device,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) @@ -176,8 +176,8 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): # The `_parameter_devices` is needed for the `parameter_devices` property # and is a list of all variable devices. Here parameter devices are all # tasks of the "ps" job. - self._parameter_devices = map("/job:ps/task:{}".format, - range(num_ps_replicas)) + self._parameter_devices = tuple(map("/job:ps/task:{}".format, + range(num_ps_replicas))) # Add a default device so that ops without specified devices will not end up # on other workers. @@ -204,24 +204,24 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. if num_gpus_per_worker > 0: - self._compute_devices = list( + self._compute_devices = tuple( map("/device:GPU:{}".format, range(num_gpus_per_worker))) else: - self._compute_devices = [_LOCAL_CPU] + self._compute_devices = (_LOCAL_CPU,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) # If there is only one GPU, put everything on that GPU. Otherwise, place # variables on CPU. if num_gpus_per_worker == 1: - assert len(list(self._compute_devices)) == 1 + assert len(self._compute_devices) == 1 self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = [_LOCAL_GPU_0] + self._parameter_devices = (_LOCAL_GPU_0,) else: self._variable_device = _LOCAL_CPU - self._parameter_devices = [_LOCAL_CPU] + self._parameter_devices = (_LOCAL_CPU,) self._is_chief = True self._cluster_spec = None @@ -417,9 +417,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): if isinstance(val, values.DistributedValues): # Return in a deterministic order. if set(val.devices) == self._canonical_compute_device_set: - return [val.get(device=d) for d in self._compute_devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] + return tuple(val.get(device=d) for d in self._compute_devices) + return tuple(val.get(device=d) for d in sorted(val.devices)) + return (val,) def value_container(self, val): if (hasattr(val, "_aggregating_container") and @@ -497,12 +497,11 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): @property def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._compute_devices) + return self._compute_devices @property def parameter_devices(self): - return list(self._parameter_devices) + return self._parameter_devices def non_slot_devices(self, var_list): return min(var_list, key=lambda x: x.name) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index d50b142c5e9ad36522b11a77219140a7b40d9bf6..d441b5af5f6aa41efde2c75d09d9589516c54992 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -290,4 +290,4 @@ class DistributionTestBase(test.TestCase): self.evaluate(strategy.group(train_ops)) global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) - self.assertEqual([1] * len(global_step_tensors), global_step_values) + self.assertEqual((1,) * len(global_step_tensors), global_step_values) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 39ed8f7cf10371c0e8dd70e2bdf53f13e8ce8383..b6f5b492017fc7dfd329e69ad9ca418ae682bc4b 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -28,6 +28,8 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -43,12 +45,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE" - - def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -145,6 +145,9 @@ class TPUStrategy(distribute_lib.DistributionStrategy): class TPUExtended(distribute_lib.DistributionStrategyExtended): """Implementation of TPUStrategy.""" + # Track what TPU devices have been initialized. + _initialized_devices = [] + def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, num_cores=None): super(TPUExtended, self).__init__(container_strategy) @@ -159,16 +162,41 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if "device:TPU:" in d.name} self._device_index = values.PerReplica(device_map) self._host_device = self.get_host_cpu_device(0) - self._tpu_devices = sorted(device_map.keys()) + self._tpu_devices = tuple(sorted(device_map.keys())) # Only create variables for the number of replicas we're running. self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run - self._require_static_shapes = True + # Initialize the TPU devices. + self._initialize_tpu() + + def _initialize_tpu(self): + """Initialize the TPU devices in a separate session and graph. + + We keep track of all the TPU devices that we're initialized as we should + only be running TPU initialize once for the entire process. + """ + master = self._tpu_cluster_resolver.master() + # Verify TPU has not already been initialized in this process. + if master in TPUExtended._initialized_devices: + logging.info("TPU master %s has already been initialized." % master) + return + + logging.info("Initializing the TPU system.") + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + self._configure(session_config) + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + sess.run([tpu.initialize_system()]) + logging.info("Finized initializing TPU system.") + + # Update Strategy state to make sure we can track device initialization. + TPUExtended._initialized_devices.append(master) + def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. @@ -380,22 +408,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - # TODO(jhseu): We need this hack because DistributionStrategies must be - # pickleable for copy.deepcopy(). Remove when initialize_system goes away. - graph = ops.get_default_graph() - tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - if tpu_init: - return tpu_init - graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION, - tpu.initialize_system()) - return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) + return [] def _finalize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - return [tpu.shutdown_system()] + return [] def _get_devices_from(self, colocate_with=None): # TODO(jhseu): Change this when we support model parallelism. @@ -487,13 +507,13 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. - return [val.get(device=d) for d in sorted(val.devices)] + return tuple(val.get(device=d) for d in sorted(val.devices)) elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should # be represented using a PerReplica wrapper instead of a list with # one entry per device. - return val - return [val] + return tuple(val) + return (val,) def value_container(self, value): return value @@ -599,4 +619,4 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): distribute_lib.require_replica_context(self) ds = self._distribution_strategy replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return [ds.extended.worker_devices[replica_id]] + return (ds.extended.worker_devices[replica_id],) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py index 557ad42752144243ae3da61b955b31398cba846e..d412b25b368260b81256fd58034330b884261b2b 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -36,7 +36,7 @@ class GraphLinearRegressionBenchmark(tf.test.Benchmark): noise_level=0.01, batch_size=batch_size, num_batches=num_batches) - iterator = dataset.make_initializable_iterator() + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) x, y = iterator.get_next() model = linear_regression.LinearModel() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 33c988fd9065e7fbe7b9aeb85cad82eb3c119f76..8882a863c30d8b222c68d6952279c3744345883c 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -41,6 +41,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@add_execution_callback @@clear_execution_callbacks +@@errstate +@@ExecutionCallback @@inf_callback @@inf_nan_callback @@nan_callback @@ -119,6 +121,8 @@ from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks +from tensorflow.python.eager.execution_callbacks import errstate +from tensorflow.python.eager.execution_callbacks import ExecutionCallback from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index e2594faf85bcf91cbe09f266e4d4211d20bdee17..364fa4eb461c62784803f0c309e3b7c5855df199 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -64,6 +64,9 @@ def condition_tensor(tensor, conditioning): """ tensor.shape[1:].assert_is_fully_defined() num_features = tensor.shape[1:].num_elements() + if conditioning.shape.ndims < 2: + raise ValueError('conditioning must be at least 2D, but saw shape: %s' + % conditioning.shape) mapped_conditioning = layers.linear( layers.flatten(conditioning), num_features) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py index 0aad769793761be69ee9d1e3416e44c7b3d8cea0..f5c7d53cf2c9aa08ba0074950983ef3ecd90168b 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py @@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase): array_ops.placeholder(dtypes.float32, (5, None)), array_ops.placeholder(dtypes.float32, (5, 1))) - with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'): + with self.assertRaisesRegexp(ValueError, 'at least 2D'): conditioning_utils.condition_tensor( array_ops.placeholder(dtypes.float32, (5, 2)), array_ops.placeholder(dtypes.float32, (5))) diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py index f7f1189bb93c611719186a697c40f371644f63a2..bc941ae9f23eaa5c46fcca95b9aba0ac0d87960a 100644 --- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -47,7 +48,7 @@ class SequenceFileDatasetTest(test.TestCase): dataset = hadoop_dataset_ops.SequenceFileDataset(filenames).repeat( num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index d3fcc8cb2a937278b4c24d6bee9a24033e55a561..5c5599858ee6879a5703d65658bf4bbd881c7e72 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -20,10 +20,9 @@ from __future__ import print_function from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape class SequenceFileDataset(dataset_ops.DatasetSource): @@ -57,16 +56,10 @@ class SequenceFileDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): return gen_dataset_ops.sequence_file_dataset( - self._filenames, nest.flatten(self.output_types)) + self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - - @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index 936b29a4f50794380d48efed99e267c6b4c44dc6..e4762c91b193f9c5e32fa2642e702e61e8e5e57f 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -27,6 +27,7 @@ import six from tensorflow.contrib.ignite.python.ops import gen_dataset_ops from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -34,10 +35,7 @@ from tensorflow.python.framework import tensor_shape @six.add_metaclass(abc.ABCMeta) class Readable(object): - """Readable abstract class that exposes methods to do reading-related - - operations. - """ + """Abstract class that exposes methods to do reading-related operations.""" @abc.abstractmethod def __init__(self): @@ -227,10 +225,7 @@ types = { class TypeTreeNode(object): - """TypeTreeNode class exposes methods to format object tree structure - - data. - """ + """TypeTreeNode class exposes methods to format object tree structure data.""" def __init__(self, name, type_id, fields=None, permutation=None): """Constructs a new instance of TypeTreeNode. @@ -692,14 +687,14 @@ class IgniteClient(TcpClient): class IgniteDataset(dataset_ops.DatasetSource): - """Apache Ignite is a memory-centric distributed database, caching, and - - processing platform for transactional, analytical, and streaming workloads, - delivering in-memory speeds at petabyte scale. This contrib package - contains an integration between Apache Ignite and TensorFlow. The - integration is based on tf.data from TensorFlow side and Binary Client - Protocol from Apache Ignite side. It allows to use Apache Ignite as a - datasource for neural network training, inference and all other + """Apache Ignite is a memory-centric distributed database. + + It acts as a caching and processing platform for transactional, analytical, + and streaming workloads, delivering in-memory speeds at petabyte scale. + This contrib package contains an integration between Apache Ignite and + TensorFlow. The integration is based on tf.data from TensorFlow side and + Binary Client Protocol from Apache Ignite side. It allows to use Apache + Ignite as a datasource for neural network training, inference and all other computations supported by TensorFlow. Ignite Dataset is based on Apache Ignite Binary Client Protocol. """ @@ -756,6 +751,9 @@ class IgniteDataset(dataset_ops.DatasetSource): self.cache_type.to_permutation(), dtype=dtypes.int32, name="permutation") + self._structure = structure.convert_legacy_structure( + self.cache_type.to_output_types(), self.cache_type.to_output_shapes(), + self.cache_type.to_output_classes()) def _as_variant_tensor(self): return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, @@ -763,13 +761,5 @@ class IgniteDataset(dataset_ops.DatasetSource): self.schema, self.permutation) @property - def output_classes(self): - return self.cache_type.to_output_classes() - - @property - def output_shapes(self): - return self.cache_type.to_output_shapes() - - @property - def output_types(self): - return self.cache_type.to_output_types() + def _element_structure(self): + return self._structure diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 4997c31a7fc7f4243d03b22fc9c01fb13a2a25a4..ba5cdfebf92c07e496ed588848d5859ff6a5bff2 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -281,6 +281,13 @@ class ImageOpsTest(test_util.TensorFlowTestCase): value.eval(), np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype())) + @test_util.run_in_graph_and_eager_modes + def test_transform_eager(self): + image = constant_op.constant([[1., 2.], [3., 4.]]) + value = image_ops.transform(image, [1] * 8) + with self.test_session(use_gpu=True): + self.assertAllEqual(self.evaluate(value), np.array([[4, 4], [4, 4]])) + class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index d4fb99a017faebe30384d739f22f4ff5fa986bc4..b25a6f7b5742917a032946fe03a0dab20e7dc1ad 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.contrib.image.ops import gen_image_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -271,8 +272,11 @@ def transform(images, raise TypeError("Images should have rank between 2 and 4.") if output_shape is None: - output_shape = tensor_util.constant_value( - array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3] + output_shape = array_ops.shape(images)[1:3] + if not context.executing_eagerly(): + output_shape_value = tensor_util.constant_value(output_shape) + if output_shape_value is not None: + output_shape = output_shape_value output_shape = ops.convert_to_tensor( output_shape, dtypes.int32, name="output_shape") diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 7129f09e8b42e48a9c768fd4a66cde3d4da9d31d..2b86331099ccae03664462987ee0c141d766c10f 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -20,9 +20,9 @@ from __future__ import print_function from tensorflow.contrib.kafka.python.ops import gen_dataset_ops from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape class KafkaDataset(dataset_ops.DatasetSource): @@ -63,13 +63,5 @@ class KafkaDataset(dataset_ops.DatasetSource): self._group, self._eof, self._timeout) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index c392adbb1d925f26515a4a4b47d75a6125b2be81..20395395281768ac429984a1e3552cfd187527a2 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -20,9 +20,9 @@ from __future__ import print_function from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape class KinesisDataset(dataset_ops.DatasetSource): @@ -81,13 +81,5 @@ class KinesisDataset(dataset_ops.DatasetSource): self._stream, self._shard, self._read_indefinitely, self._interval) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 0a4d2c6d4cb5cad7da93cea89478bc0fca2ac4d6..d791418c9d0f887058ceb535092fa8122da1aa75 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1459,13 +1459,6 @@ class DropoutTest(test.TestCase): class FlattenTest(test.TestCase): - def testInvalidRank(self): - with ops.Graph().as_default() as g, self.session(g): - inputs = array_ops.placeholder(dtype=dtypes.float32) - inputs.set_shape(tensor_shape.TensorShape((5,))) - with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): - _layers.flatten(inputs) - def testUnknownLastDim(self): with ops.Graph().as_default() as g, self.session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) @@ -1502,6 +1495,12 @@ class FlattenTest(test.TestCase): images.get_shape().num_elements()) self.assertEqual(output.get_shape()[0], images.get_shape()[0]) + def testFlatten0D(self): + with self.cached_session(): + scalars = random_ops.random_uniform((5,), seed=1, name='scalars') + output = _layers.flatten(scalars) + self.assertEqual(output.shape, (5, 1)) + def testFlattenBatchSize(self): height, width = 3, 3 with self.cached_session() as sess: diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 8466dc36d13e223aed4f1dfe8e39a6f91c99fa55..d49834dc860a8b4341ddd3720fde52281f7474f7 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for SdcaModel.""" +"""Tests for SdcaModel (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f3f1dcd98db5ae24af154d1f0851a0688d2bc611..c056a12fa5307a7e9ac4cf30e1386ddfd5cd7d75 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Proximal stochastic dual coordinate ascent optimizer for linear models.""" +# pylint: disable=line-too-long +"""Proximal stochastic dual coordinate ascent optimizer for linear models (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" +# pylint: enable=line-too-long from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -40,6 +47,7 @@ from tensorflow.python.ops import variables as var_ops from tensorflow.python.ops.nn import log_poisson_loss from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits from tensorflow.python.summary import summary +from tensorflow.python.util import deprecation __all__ = ['SdcaModel'] @@ -48,7 +56,7 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - Loss functions supported: + Loss functions supported: * Binary logistic loss * Squared loss @@ -109,6 +117,10 @@ class SdcaModel(object): ``` """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, examples, variables, options): """Create a new sdca optimizer.""" diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a001555e8f257c88a52fdb40d4181f5cd9c92e84..a28394964a12013c43d85701b5a0ab5c559afd62 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sharded mutable dense hash table.""" +"""Sharded mutable dense hash table (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -28,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation # TODO(rohanj): This should subclass Checkpointable and implement @@ -45,6 +51,10 @@ class ShardedMutableDenseHashTable(object): # TODO(andreasst): consider moving this to lookup module + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, key_dtype, value_dtype, diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 2b56d0fa3a8b8564b7c73a62bd99cc900d6f5c54..2d1457f9e4cc576da696be191e718814dd9ff4e5 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sharded_mutable_dense_hashtable.py.""" +"""Tests for sharded_mutable_dense_hashtable.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py index 003795233ff2b28e33fc10388ef25efb63c43bb0..64730f8eed1ff9bfcd4a980dceb28abb98e39f73 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sparse feature column.""" +"""Sparse feature column (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope +from tensorflow.python.util import deprecation class SparseFeatureColumn(object): @@ -68,6 +74,10 @@ class SparseFeatureColumn(object): @@feature_values """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, example_indices, feature_indices, feature_values): """Creates a `SparseFeatureColumn` representation. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py index 51c4f68543da2f563481cc2d35b556796616cf9d..0ae780e1a100c7dadde7196803f2ae0d4bcb2334 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sparse_feature_column.py.""" +"""Tests for sparse_feature_column.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index b396c527673902d61072dc9cf7d2766476be8369..2a5232b476712a96f84be0f4725beb78bc138297 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,11 +30,13 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow -# 1.10 branch does not work. `make distclean` fails and blocks the build -# process. For now we're hardcoding to the version which is used by -# TensorFlow 1.9. -PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz" + +# Note: The protobuf repo needs to be cloned due to its submodules. +# These variables contain the GitHub repo and the sha, from `tensorflow/workspace.bzl`, +# from which to clone it from and checkout to. +readonly PROTOBUF_REPO="https://github.com/protocolbuffers/protobuf.git" +readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/archive/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1 | awk '{print substr($0, index($0, "archive") + 8, index($0, "tar") - index($0, "archive") - 9) }')" + # TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.bazel.build. RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" @@ -91,11 +93,34 @@ download_and_extract() { find "${dir}" -type f -name '*BUILD' -delete } +function clone_repository() { + local repo_url="${1}" + local destination_directory="${2}" + local commit_sha="${3}" + + if [[ -d "${destination_directory}" ]]; then + rm -rf "${destination_directory}" + fi + + git clone "${repo_url}" "${destination_directory}" + + pushd "$(pwd)" 1>/dev/null + + cd "${destination_directory}" + + if [[ -n "${commit_sha}" ]]; then + git checkout "${PROTOBUF_TAG}" + fi + + git submodule update --init + + popd 1>/dev/null +} + download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync" -download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf" download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2" download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" download_and_extract "${DOUBLE_CONVERSION_URL}" "${DOWNLOADS_DIR}/double_conversion" @@ -106,6 +131,8 @@ download_and_extract "${CUB_URL}" "${DOWNLOADS_DIR}/cub/external/cub_archive" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +clone_repository "${PROTOBUF_REPO}" "${DOWNLOADS_DIR}/protobuf" "${PROTOBUF_TAG}" + replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index 062deb74b165329d8e72efa73b9d81f4174f8831..9aabc4bec3053871e3ff6cd3a88fd76d293f48cc 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 6c203e5519e6a66d20e2509eca3c74eb66bf32c7..fa1a7aaff0aa59a6a64b1f0bf836a273926d785d 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import optimizer from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook +from tensorflow.python.training.saving import saveable_object_util LOCAL_VARIABLE_NAME = 'local_center_variable' GLOBAL_VARIABLE_NAME = 'global_center_variable' @@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.trainable_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) swapped_var_list = {} for key, var in var_list.items(): @@ -464,4 +465,4 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): def after_create_session(self, session, coord): """Run initialization ops""" - session.run(self._variable_init_op) \ No newline at end of file + session.run(self._variable_init_op) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index b7fd2d2fb9db3eed15eb1cc2934199939790b1c0..bf3e5c51f78cc3ca3c7c77009c9cf428c4988953 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages from tensorflow.python.training import optimizer from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util class MovingAverageOptimizer(optimizer.Optimizer): @@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) v_name_to_tensor = {} for k, tensor_or_list in six.iteritems(var_list): diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 200b0d200826a6212a236680327f4daf7d07831f..8b8065c678e11e8fc237e71cf1d392ced5c22ada 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -59,6 +59,23 @@ class DecoupledWeightDecayExtension(object): Note that this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + schedule = tf.train.piecewise_constant(tf.train.get_global_step(), + [10000, 15000], [1e-0, 1e-1, 1e-2]) + lr = 1e-1 * schedule() + wd = lambda: 1e-4 * schedule() + + # ... + + optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, + weight_decay=wd, + momentum=0.9, + use_nesterov=True) + ``` """ def __init__(self, weight_decay, **kwargs): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 73a556f0b299614b098ceef0fb9d32f148227b03..7fb23abc38d9dc101204ed83808aebe5a8ef1e78 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -25,6 +25,7 @@ import abc import six from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context as distribute_ctx from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -997,10 +997,10 @@ class OptimizerV2(optimizer_v1.Optimizer): with ops.control_dependencies([update_ops]): finish_updates = distribution.extended.update_non_slot( non_slot_devices, finish, group=False) - # We said grouped=False, which means finish_updates is always a list. - # It will be [None] when finish() returns None. - if finish_updates == [None]: - finish_updates = [update_ops] + # We said group=False, which means finish_updates is always a tuple. + # It will be (None,) when finish() returns None. + if finish_updates == (None,): + finish_updates = (update_ops,) # Update `global_step` (if any). if global_step is None: diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index d50b52b8ff1ce8188ab52c6968d716378efd9daa..53a3bc63e1d770b451846c45370fdee9ffa72d70 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -42,6 +42,7 @@ py_library( name = "saved_model_predictor", srcs = ["saved_model_predictor.py"], srcs_version = "PY2AND3", + visibility = ["//learning/brain/contrib/learn/tpu:__subpackages__"], deps = [ ":base_predictor", "//tensorflow/contrib/saved_model:saved_model_py", diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD index c461a7145e27c4238161cec989448be807acd543..76db9aecf615d0a94f65cd7ea799db245828db1c 100644 --- a/tensorflow/contrib/rate/BUILD +++ b/tensorflow/contrib/rate/BUILD @@ -34,6 +34,11 @@ py_test( name = "rate_test", size = "small", srcs = ["rate_test.py"], + tags = [ + "manual", # TODO(b/120555555) + "no_oss", # TODO(b/120555555) + "notap", # TODO(b/120555555) + ], deps = [ ":rate", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py index 8fcd7aeef6a6964902666a4f3c17e05b0c7b52ee..f31bdbd399c9de4f2f5d557b75b1ece6d64a765e 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import lanczos from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -80,7 +81,8 @@ if __name__ == "__main__": for shape in [[4, 4], [7, 4], [5, 8]]: for orthogonalize in True, False: for steps in range(1, min(shape) + 1): - for use_static_shape in True, False: + # TF2 does not support placeholders so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_%s_%s_staticshape_%s" % ( dtype.__name__, "_".join(map(str, shape)), orthogonalize, steps, use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py index 2a9100903aae5689919a6b25fcb18ff192f250b3..841a41a2339824ab8ca15f4bdd74be697cd6fe9f 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import least_squares from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -76,7 +77,8 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for shape in [[4, 4], [8, 5], [3, 7]]: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, "_".join(map(str, shape)), use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index a0e6eb87bc06fb1303a7eb86fa6760458f20a9b9..10807f7a80617e56abeb6d13ce419a49a2269aac 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import linear_equations from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -113,7 +114,8 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for size in 1, 4, 10: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): shape = [size, size] arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, size, use_static_shape) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md index caf8b6db0dc0a220d593f9c0afc9464ca51a1e05..a9c2ad78a3db409e6e8669c48c4df37c8db19c4b 100644 --- a/tensorflow/contrib/tensorrt/README.md +++ b/tensorflow/contrib/tensorrt/README.md @@ -1,8 +1,46 @@ -# Using TensorRT in TensorFlow +# Using TensorRT in TensorFlow (TF-TRT) -This module provides necessary bindings and introduces TRT_engine_op operator -that wraps a subgraph in TensorRT. This is still a work in progress but should -be useable with most common graphs. +This module provides necessary bindings and introduces `TRTEngineOp` operator +that wraps a subgraph in TensorRT. This module is under active development. + +## Installing TF-TRT + +Currently TensorFlow nightly builds include TF-TRT by default, which means you +don't need to install TF-TRT separately. You can pull the latest TF containers +from docker hub or install the latest TF pip package to get access to the latest +TF-TRT. + +If you want to use TF-TRT on NVIDIA Jetson platform, you can find the download +links for the relevant TensorFlow pip packages here: +https://docs.nvidia.com/deeplearning/dgx/index.html#installing-frameworks-for-jetson + +## Installing TensorRT + +In order to make use of TF-TRT, you will need a local installation of TensorRT. +Installation instructions for compatibility with TensorFlow are provided on the +[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. + +## Examples + +You can find example scripts for running inference on deep learning models in +this repository: https://github.com/tensorflow/tensorrt + +We have used these examples to verify the accuracy and performance of TF-TRT. +For more information see +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models). + +## Documentation + +[TF-TRT documentation](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) +gives an overview of the supported functionalities, provides tutorials and +verified models, explains best practices with troubleshooting guides. + +## Tests + +TF-TRT includes both Python tests and C++ unit tests. Most of Python tests are +located in the test directory and they can be executed using `bazel test` or +directly with the Python command. Most of the C++ unit tests are used to test +the conversion functions that convert each TF op to a number of TensorRT layers. ## Compilation @@ -18,12 +56,3 @@ bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ ``` -After the installation of tensorflow package, TensorRT transformation will be -available. An example use can be found in test/test_tftrt.py script - -## Installing TensorRT 3.0.4 - -In order to make use of TensorRT integration, you will need a local installation -of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt). -Installation instructions for compatibility with TensorFlow are provided on the -[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 3b32f72bc1f220fd6730c71e3d2b3b6b806b748e..ae211a93c3279ff1d6de2f9c9a4b849fc8cd578d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -132,6 +132,8 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { "Min", "Relu6", "Square", + "ExpandDims", + "Squeeze", }; bool is_supported_op_type = (candidate_ops.count(node->type_string()) || @@ -585,6 +587,14 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } } } + // We don't support segments with no inputs. Fall back to native TF here to + // avoid crash later. Constant folding should've folded the ops that make up + // these segments. + if (inputs.empty()) { + return tensorflow::errors::Internal( + "Segment has no inputs (possible " + "constfold failure)"); + } const bool calibrate_int8 = (info.precision_mode == INT8MODE && info.use_calibration); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index fee095668e5aef44316ff15c1d8572b2ecd960df..777a80bbc4da7a260cf85d0a7bc5ec16f4cd3cab 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -120,6 +120,15 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } +Status TensorShapeArrayToTrtDims(const std::vector& shape, + nvinfer1::Dims* out, + bool ignore_first_dim = false) { + PartialTensorShape tensor_shape; + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); + *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim); + return tensorflow::Status::OK(); +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -1880,6 +1889,133 @@ tensorflow::Status ConvertReshape(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertExpandDims(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument( + "Two inputs expected for ExpandDims, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "ExpandDims expects tensor for input, at ", node_def.name()); + } + if (!inputs.at(1).is_weights()) { + return tensorflow::errors::InvalidArgument( + "ExpandDims expects weights for axis, at ", node_def.name()); + } + // Get input shape as vector. + TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Get axis to expand on. + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (weights.count() != 1) { + return tensorflow::errors::InvalidArgument( + "ExpandDims axis must be a scalar, at ", node_def.name()); + } + const int* weights_ptr = + static_cast(const_cast(weights.GetValues())); + int axis = weights_ptr[0]; + // Make sure axis is valid. + if ((axis < (-input_rank - 1)) || (axis > input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank + 1; + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Modifying batch dimension is not supported for ExpandDims, at ", + node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // ExpandDims: Insert new dim of size 1. + input_dims.insert(input_dims.begin() + axis, 1); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertSqueeze(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + "One input expected for Squeeze, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "Squeeze expects tensor for input, at ", node_def.name()); + } + // Get input shape. + TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Mark axes to remove by setting them to 0. + TFAttrs attrs(node_def); + auto squeeze_dims = attrs.get>("squeeze_dims"); + if (squeeze_dims.size() == 0) { + return tensorflow::errors::Unimplemented( + "Squeeze is only implemented for explicit dims, at ", node_def.name()); + } + for (int axis : squeeze_dims) { + // Make sure axis is valid. + if ((axis < -input_rank) || (axis >= input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank; + // Don't squeeze batch dim. + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Cannot squeeze batch dimension, at ", node_def.name()); + } + // Make sure target dimension is size 1. + if (input_dims[axis] != 1) { + return tensorflow::errors::InvalidArgument( + "Cannot squeeze a dimension which isn't size 1, at ", + node_def.name()); + } + // Mark dim for removal by setting to 0. + input_dims[axis] = 0; + } + if (params->validation_only) return Status::OK(); + + // Remove all dims which are equal to 0. + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), + input_dims.end()); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertConv2D(OpConverterParams* params) { return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); } @@ -3156,6 +3292,8 @@ static void RegisterValidatableOpConverters( (*registration)["MatMul"] = ConvertMatMul; (*registration)["Relu6"] = ConvertRelu6; (*registration)["Square"] = ConvertSquare; + (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["Squeeze"] = ConvertSqueeze; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index 443033379f0d6554784d44412a02aa8cb035ab08..c37a43dd5def9daf3c5d70720c6db2aab20db077 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -2113,6 +2113,242 @@ TEST_F(OpConverterTest, ConvertActivation) { } } +TEST_F(OpConverterTest, ConvertExpandDims) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Two inputs expected for ExpandDims, at my_expanddims"); + } + + // Get the NodeDef for ExpandDims. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto expanddims = + ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); + const NodeDef& node_def = expanddims.operation.node()->def(); + + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ExpandDims expects tensor for input, at my_expanddims"); + } + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ExpandDims expects weights for axis, at my_expanddims"); + } + { + // Add dim at batch dimension, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {1}, {0}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Add dim at batch dimension via negative axis, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Axis > rank(input), should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {5}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + { + // Axis < -rank(input)-1, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-6}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, int axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + int axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kExpandDimsOKCases = 8; + TestParams ok_params[kExpandDimsOKCases] = { + TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}}, + TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}}, + TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}}, + TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}}, + }; + for (int i = 0; i < kExpandDimsOKCases; ++i) { + Reset(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", {1}, {ok_params[i].axis}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_expanddims", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertSqueeze) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "One input expected for Squeeze, at my_squeeze"); + } + { + // No attrs, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); + const NodeDef& node_def = squeeze.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze is only implemented for explicit dims, at my_squeeze"); + } + + // Get the NodeDef for Squeeze. + auto get_squeeze_nodedef = [](std::vector axis) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + ops::Squeeze::Attrs squeeze_attrs; + squeeze_attrs.axis_ = gtl::ArraySlice(axis); + auto squeeze = + ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); + return squeeze.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze expects tensor for input, at my_squeeze"); + } + { + // Squeeze batch dim, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze batch dim via negative axis, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze >= rank(input), should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + { + // Squeeze < -rank(input), should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-5}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, const std::vector& axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + std::vector axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kSqueezeOKCases = 10; + TestParams ok_params[kSqueezeOKCases] = { + TestParams{{1, 2, 3}, {1}, {2, 3}}, + TestParams{{1, 2, 3}, {-3}, {2, 3}}, + TestParams{{2, 3, 1}, {3}, {2, 3}}, + TestParams{{2, 3, 1}, {-1}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}}, + TestParams{{1, 6}, {1}, {6}}, + TestParams{{6, 1}, {2}, {6}}, + }; + for (int i = 0; i < kSqueezeOKCases; ++i) { + Reset(); + NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); + AddTestTensor("input", ok_params[i].input_dims); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_squeeze", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index 0cd733dca13462ac8f4478544005ae4000f711f1..563232fc12675d9e1b32b7ab461591af57beadb9 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -51,8 +51,10 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): c = constant_op.constant(3.0, name="c%d_3" % i) q = math_ops.add(q, c, name="add%d_3" % i) if i == 0: + axis = constant_op.constant(-1, dtype=dtypes.int32, name="axis") for j in range(2): - q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j)) + q = array_ops.expand_dims(q, axis, name="expand%d_%d" % (i, j)) + q = self.trt_incompatible_op(q) q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i) outputs.append(q) # Combine both paths @@ -70,7 +72,7 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): return { "TRTEngineOp_0": [ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", - "abs0_2" + "abs0_2", "expand0_0", "expand0_1", "axis" ], "TRTEngineOp_1": [ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index 9fc50e05952abd335e196dce8fc8a81056d7007d..b6e5e32db1236684a06c2d44298b9a3d39667152 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -106,10 +106,7 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return [ - "TRTEngineOp_0", "TRTEngineOp_1", "TRTEngineOp_2", "TRTEngineOp_3", - "TRTEngineOp_4" - ] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md index b29d1acacf17b57549558be45c853566817c1729..f40e76f554e8815aac96344d8cb0b911bafdd712 100644 --- a/tensorflow/contrib/tfprof/README.md +++ b/tensorflow/contrib/tfprof/README.md @@ -1,7 +1,5 @@ # tfprof: TensorFlow Profiler and Beyond -

Please use `tf.profiler.xxx` instead of `tf.contrib.tfprof.xxx`

-

Full Document in tensorflow/core/profiler/README.md

diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 05d2ebd2e8a3292a95df0e2f976df0e2871063f8..4bf3a0463d9046eea2f60e9154fca1357e728215 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -79,6 +79,7 @@ py_library( "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:summary_ops_v2", diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index b4b06a40a2c8aaa97ff82baf93c8f2d55a587e37..ef35e84ba5205fb76e5afe77e670d87197ca8405 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -98,7 +98,7 @@ Status DumpOpProfileToLogDirectory(StringPiece run_dir, if (!status.ok()) { return errors::Internal( "Failed to convert op profile to json. Skipping... ", - string(status.error_message())); + string(status.message())); } TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, json)); if (os) { diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 63641e00c5dbf4b4e635ecfea8bef98c7d0b7075..a081c4354a779d37140338793e66844c3fcf7a12 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -90,12 +90,12 @@ def main(unused_argv=None): tf_version = tf.__version__ print('TensorFlow version %s detected' % tf_version) - if FLAGS.service_addr is None and FLAGS.tpu is None: + if not FLAGS.service_addr and not FLAGS.tpu: sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None - if FLAGS.service_addr is not None: - if FLAGS.tpu is not None: + if FLAGS.service_addr: + if FLAGS.tpu: tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' '--tpu and using --service_addr.') service_addr = FLAGS.service_addr diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index b58d05eac56f3586e183333f7c1a3867ee57456c..52d87b800401c3e584da9843916cfc7a767c082a 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -70,7 +70,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -94,7 +94,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -121,7 +121,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -154,7 +154,7 @@ class DatasetsTest(test.TestCase): os.path.join(self.get_temp_dir(), 'fixed_length*'), filetype=FixedLengthFile) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -177,7 +177,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( dataset_ops.Dataset.range(10), filetype=gen_dataset) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index cf3b2e68e940652220983c98e3a0acb68cf88d89..4ce194590342555a7c4e9e119bf51e516a37a715 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -133,7 +133,7 @@ def _tpu_session_context(): An error occurred connecting or initializing your TPU. The session has been reset. re-run keras_to_tpu_model to create a new session. -""" + e) +""" + str(e)) def setup_tpu_session(cluster_resolver): @@ -729,7 +729,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): dummy_x_shape[0] *= tpu_assignment.num_towers dummy_y_shape = dataset.output_shapes[1].as_list() dummy_y_shape[0] *= tpu_assignment.num_towers - self._iterator = dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(dataset) K.get_session().run(self._iterator.initializer) self._get_next_ops = [] @@ -1676,14 +1676,10 @@ class KerasTPUModel(models.Model): callbacks, self, do_validation=do_validation, - val_inputs=val_inputs, - val_targets=val_targets, - val_sample_weights=val_sample_weights, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, samples=num_training_samples, - validation_steps=validation_steps, verbose=verbose, count_mode=count_mode) diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index a95275487899c4770ef99b620a7671eec2bb81eb..3e463823c820a3ef8628324f77e1a9caf8d385d5 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -43,12 +43,19 @@ class CoordinatorShutdownException(Exception): pass +def _clone_session(session, graph=None): + return session_lib.Session( + target=session.sess_str, + config=session._config, # pylint: disable=protected-access + graph=graph if graph else session.graph) + + def _make_heartbeat_op(session, device, request_ph): """Return a heartbeat op or None if heartbeats are not supported by device.""" try: # Test if we can connect in a isolated graph + session with ops.Graph().as_default(): - with session_lib.Session(target=session.sess_str) as temp_session: + with _clone_session(session) as temp_session: with ops.device(device): heartbeat_op = tpu_ops.worker_heartbeat('') options = config_pb2.RunOptions(timeout_in_ms=5000) @@ -220,6 +227,7 @@ class WatchdogManager(threading.Thread): self.ping_interval = ping_interval self.shutdown_timeout = shutdown_timeout self.daemon = True + self._config = session._config # pylint: disable=protected-access self._target = session.sess_str self._running = False self._devices = devices @@ -234,6 +242,7 @@ class WatchdogManager(threading.Thread): self._session = session_lib.Session( target=self._target, graph=self._graph, + config=self._config, ) if self._devices is None: @@ -334,8 +343,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): with self._graph.as_default(): logging.info('Installing graceful shutdown hook.') - self._session = session_lib.Session( - target=training_session.sess_str, graph=self._graph) + self._session = _clone_session(training_session, self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) self._heartbeat_supported = self._workers.num_workers() > 0 diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 7171587ff7298982423a5046d85d1970a4d6b1cb..84816d70d081854af9252e771df49957ec717db3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -45,6 +45,7 @@ from tensorflow.contrib.training.python.training import hparam from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest as data_nest from tensorflow.python.estimator import estimator as estimator_lib @@ -412,12 +413,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=True, - rendezvous=None): + rendezvous=None, + master=None, + session_config=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops self._rendezvous = rendezvous - + self._master = master + self._session_config = session_config self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) @@ -429,11 +433,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_ops = [] if self._should_initialize_tpu: - self._init_ops = [tpu.initialize_system(job=self._master_job)] self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] else: - self._init_ops = [] self._finalize_ops = [] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() @@ -475,11 +478,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return _OpQueueContext(name=name, target=target, args=args) def after_create_session(self, session, coord): - logging.info('Init TPU system') - start = time.time() + if self._should_initialize_tpu: + logging.info('Init TPU system') + start = time.time() + with ops.Graph().as_default(): + with tf_session.Session( + self._master, config=self._session_config) as sess: + sess.run(tpu.initialize_system(job=self._master_job)) + logging.info('Initialized TPU in %d seconds', time.time() - start) + session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - logging.info('Initialized TPU in %d seconds', time.time() - start) self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) @@ -521,13 +530,16 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): + def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None, + master=None, session_config=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous) + rendezvous=rendezvous, + master=master, + session_config=session_config) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -2561,6 +2573,8 @@ class TPUEstimator(estimator_lib.Estimator): run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], + master=self._config.master, + session_config=self._session_config, ), InstallSignalHandlerHook() ]) @@ -2663,8 +2677,10 @@ class TPUEstimator(estimator_lib.Estimator): eval_update_ops + host_ops, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode]), - ] + input_hooks + rendezvous=self._rendezvous[mode], + master=self._config.evaluation_master, + session_config=self._session_config, + )] + input_hooks if eval_hooks: hooks.extend(eval_hooks) @@ -2735,7 +2751,9 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), + ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], + master=self._config.master, + session_config=self._session_config), ] + input_hooks if prediction_hooks: @@ -3081,7 +3099,7 @@ class _Inputs(object): The initializer must be run before calling `features_and_labels`. """ - self._iterator = self._dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(self._dataset) return self._iterator.initializer def features_and_labels(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py index 55235556de0214a8e04fb85469cd1d8e4656fb56..e3ea983abfd24d03c964fbc647b56262e15e0a96 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -21,8 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.python import data as dataset_lib from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -34,10 +34,10 @@ def make_input_fn(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb}) dataset = dataset.batch(batch_size) return dataset @@ -50,10 +50,10 @@ def make_input_fn_with_labels(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb)) dataset = dataset.batch(batch_size) return dataset @@ -71,7 +71,7 @@ class TPUEstimatorStoppingSignalsTest(test.TestCase): with ops.Graph().as_default(): dataset = input_fn(params) - features = dataset_lib.make_one_shot_iterator(dataset).get_next() + features = dataset_ops.make_one_shot_iterator(dataset).get_next() # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. self.assertIsNone(features['a'].shape.as_list()[0]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index ec682e5829c4df536a043334b74200f0b6259df3..d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -52,6 +52,7 @@ def _query_tpu_system_metadata(master_address, cluster_def=None, devices = [] device_dict = collections.defaultdict(list) + # TODO(b/120564445): Replace with standard library for retries. retry_count = 1 while True: logging.info('Querying Tensorflow master (%s) for TPU system metadata.', diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c268605711fb73f37773ce7b4181bf17f2a3a4fa..66714235b535c14a8f13c40bb2a4df8d7494dc05 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -492,7 +492,10 @@ cc_library( ":platform_env_internal_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [ + "//tensorflow/c:__subpackages__", + "//tensorflow/core:__subpackages__", + ], deps = [ ":error_codes_proto_cc", ":lib", @@ -1679,6 +1682,9 @@ filegroup( # operators, use :android_tensorflow_lib if you want full operator # support. # +# If you just need TensorFlow types, e.g. Tensors, use +# :android_tensorflow_lib_lite_no_runtime. +# # Compiles to a trivial library on non-Android to prevent irrelevant # build errors. If not building this as part of an android_binary, # a command such as the following must be used: @@ -1689,7 +1695,33 @@ filegroup( cc_library( name = "android_tensorflow_lib_lite", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ], + linkopts = ["-lz"], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":mobile_additional_lib_deps", + ":protos_all_cc_impl", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@protobuf_archive//:protobuf", + ], + alwayslink = 1, +) + +cc_library( + name = "android_tensorflow_lib_lite_nortti", + srcs = if_android(["//tensorflow/core:android_srcs"]), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ] + tf_opts_nortti_if_android(), linkopts = ["-lz"], tags = [ "manual", @@ -1801,52 +1833,21 @@ cc_library( # Does not contain operators. In contrast to android_tensorflow_lib_lite, # this links in framework support for all types, relying on selective # registration of ops to prune code size. -cc_library( +# +# TODO(gonnet): Move all users of these aliases to the corresponding +# :android_tensorflow_lib_lite* targets and remove. +alias( name = "android_tensorflow_lib_selective_registration", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], + actual = ":android_tensorflow_lib_lite", visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@com_google_absl//absl/container:flat_hash_set", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, ) # Android library for use with the SELECTIVE_REGISTRATION feature with # no proto_rtti. -cc_library( +alias( name = "android_tensorflow_lib_selective_registration_nortti", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], + actual = ":android_tensorflow_lib_lite_nortti", visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@com_google_absl//absl/container:flat_hash_set", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, ) filegroup( @@ -2087,9 +2088,7 @@ tf_proto_library_cc( srcs = ["protobuf/master.proto"], cc_api_version = 2, protodeps = tf_additional_all_protos(), - visibility = [ - "//tensorflow:internal", - ], + visibility = ["//tensorflow:internal"], ) tf_proto_library_cc( diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index d38a8424eb13009fbf84d7511fb1325085d8b809..7405e2ace72d1c08cf87cc0040e617379e18149b 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalDatasetCardinality.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDatasetCardinality.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ac014bcc5e6ae48cdecd6acefca267da3f2fe4f1 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDatasetCardinality.pbtxt @@ -0,0 +1,21 @@ +op { + graph_op_name: "ExperimentalDatasetCardinality" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <