diff --git a/.gitignore b/.gitignore index 09734fe4974935956fd599f7f86cd5c4d195d5e2..9ae0d9c96f188bc6357832f22b4125694302b104 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,8 @@ cmake_build/ .idea/** /build/ /tensorflow/core/util/version_info.cc +/tensorflow/python/framework/fast_tensor_util.cpp +Pods +Podfile.lock +*.pbxproj +*.xcworkspacedata diff --git a/README.md b/README.md index 6339c57c95032d490c8acf02f21967d4ee35ddb9..24bbb6cec10e16c7b6ae37b7cf8b6f90ebe5e5dd 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,11 @@ People who are a little more adventurous can also try our nightly binaries: **Nightly pip packages** * We are pleased to announce that TensorFlow now offers nightly pip packages -under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) project on pypi. -Simply run `pip install tf-nightly` in a clean environment to install the nightly -tensorflow build. We currently only support CPU packages on Linux, Mac, and Windows. -GPU packages on all platforms will arrive soon! +under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and +[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on pypi. +Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean +environment to install the nightly TensorFlow build. We support CPU and GPU +packages on Linux, Mac, and Windows. **Individual whl files** diff --git a/RELEASE.md b/RELEASE.md index 634b31b82b9758d78973d655b7f9c5c1a7fe214e..4a33bce8b256136e9310336b2078d76b7d46fd9b 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,9 +1,51 @@ # Release 1.4.0 ## Major Features And Improvements +* `tf.keras` is now part of the core TensorFlow API. +* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of + the core TensorFlow API. + * The API is now subject to backwards compatibility guarantees. + * For a guide to migrating from the `tf.contrib.data` API, see the + [README](https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/contrib/data/README.md). + * Major new features include `Dataset.from_generator()` (for building an input + pipeline from a Python generator), and the `Dataset.apply()` method for + applying custom transformation functions. + * Several custom transformation functions have been added, including + `tf.contrib.data.batch_and_drop_remainder()` and + `tf.contrib.data.sloppy_interleave()`. +* Add `train_and_evaluate` for simple distributed `Estimator` training. +* Add `tf.spectral.dct` for computing the DCT-II. +* Add Mel-Frequency Cepstral Coefficient support to `tf.contrib.signal` + (with GPU and gradient support). +* Add a self-check on `import tensorflow` for Windows DLL issues. +* Add NCHW support to `tf.depth_to_space` on GPU. +* SinhArcsinh (scalar) distribution added to `contrib.distributions`. +* Make `GANEstimator` opensource. +* `Estimator.export_savedmodel()` now includes all valid serving signatures + that can be constructed from the Serving Input Receiver and all available + ExportOutputs. For instance, a classifier may provide regression- and + prediction-flavored outputs, in addition to the classification-flavored one. + Building signatures from these allows TF Serving to honor requests using the + different APIs (Classify, Regress, and Predict). Furthermore, + `serving_input_receiver_fn()` may now specify alternative subsets of nodes + that may act as inputs. This allows, for instance, producing a prediction + signature for a classifier that accepts raw `Tensors` instead of a serialized + `tf.Example`. +* Add `tf.contrib.bayesflow.hmc`. +* Add `tf.contrib.distributions.MixtureSameFamily`. +* Make `Dataset.shuffle()` always reshuffles after each iteration by default. +* Add `tf.contrib.bayesflow.metropolis_hastings`. +* Add `log_rate` parameter to `tf.contrib.distributions.Poisson`. +* Extend `tf.contrib.distributions.bijector` API to handle some non-injective + transforms. * Java: - * Generics (e.g., `Tensor`) for improved type-safety (courtesy @andrewcmyers). + * Generics (e.g., `Tensor`) for improved type-safety + (courtesy @andrewcmyers). * Support for multi-dimensional string tensors. + * Support loading of custom operations (e.g. many in `tf.contrib`) on Linux + and OS X +* All our prebuilt binaries have been built with CUDA 8 and cuDNN 6. + We anticipate releasing TensorFlow 1.5 with CUDA 9 and cuDNN 7. ## Bug Fixes and Other Changes * `tf.nn.rnn_cell.DropoutWrapper` is now more careful about dropping out LSTM @@ -15,6 +57,57 @@ * Removed `tf.contrib.training.python_input`. The same behavior, in a more flexible and reproducible package, is available via the new `tf.contrib.data.Dataset.from_generator` method! +* Fix `tf.contrib.distributions.Affine` incorrectly computing log-det-jacobian. +* Fix `tf.random_gamma` incorrectly handling non-batch, scalar draws. +* Resolved a race condition in TensorForest TreePredictionsV4Op. +* Google Cloud Storage file system and Hadoop file system support are now + default build options. +* Custom op libraries must link against libtensorflow_framework.so + (installed at `tf.sysconfig.get_lib()`). + +## Breaking Changes to the API +* The signature of the `tf.contrib.data.rejection_resample()` function has been + changed. It now returns a function that can be used as an argument to + `Dataset.apply()`. +* Remove `tf.contrib.data.Iterator.from_dataset()` method. Use + `Dataset.make_initializable_iterator()` instead. +* Remove seldom used and unnecessary `tf.contrib.data.Iterator.dispose_op()`. +* Reorder some TFGAN loss functions in a non-backwards compatible way. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Abdullah Alrasheed, abenmao, Adam Salvail, Aditya Dhulipala, Ag Ramesh, +Akimasa Kimura, Alan Du, Alan Yee, Alexander, Amit Kushwaha, Amy, Andrei Costinescu, +Andrei Nigmatulin, Andrew Erlichson, Andrew Myers, Andrew Stepanov, Androbin, AngryPowman, +Anish Shah, Anton Daitche, Artsiom Chapialiou, asdf2014, Aseem Raj Baranwal, Ash Hall, +Bart Kiers, Batchu Venkat Vishal, ben, Ben Barsdell, Bill Piel, Carl Thomé, Catalin Voss, +Changming Sun, Chengzhi Chen, Chi Zeng, Chris Antaki, Chris Donahue, Chris Oelmueller, +Chris Tava, Clayne Robison, Codrut, Courtial Florian, Dalmo Cirne, Dan J, Darren Garvey, +David Kristoffersson, David Norman, David RöThlisberger, DavidNorman, Dhruv, DimanNe, +Dorokhov, Duncan Mac-Vicar P, EdwardDixon, EMCP, error.d, FAIJUL, Fan Xia, +Francois Xavier, Fred Reiss, Freedom" Koan-Sin Tan, Fritz Obermeyer, Gao, Xiang, +Guenther Schmuelling, Guo Yejun (郭叶军), Hans Gaiser, HectorSVC, Hyungsuk Yoon, +James Pruegsanusak, Jay Young, Jean Wanka, Jeff Carpenter, Jeremy Rutman, Jeroen BéDorf, +Jett Jones, Jimmy Jia, jinghuangintel, jinze1994, JKurland, Joel Hestness, joetoth, +John B Nelson, John Impallomeni, John Lawson, Jonas, Jonathan Dekhtiar, joshkyh, Jun Luan, +Jun Mei, Kai Sasaki, Karl Lessard, karl@kubx.ca, Kb Sriram, Kenichi Ueno, Kevin Slagle, +Kongsea, Lakshay Garg, lhlmgr, Lin Min, liu.guangcong, Loki Der Quaeler, Louie Helm, +lucasmoura, Luke Iwanski, Lyndon White, Mahmoud Abuzaina, Marcel Puyat, Mark Aaron Shirley, +Michele Colombo, MtDersvan, Namrata-Ibm, Nathan Luehr, Naurril, Nayana Thorat, Nicolas Lopez, +Niranjan Hasabnis, Nolan Liu, Nouce, Oliver Hennigh, osdamv, Patrik Erdes, +Patryk Chrabaszcz, Pavel Christof, Penghao Cen, postBG, Qingqing Cao, Qingying Chen, qjivy, +Raphael, Rasmi, raymondxyang, Renze Yu, resec, Roffel, Ruben Vereecken, Ryohei Kuroki, +sandipmgiri, Santiago Castro, Scott Kirkland, Sean Vig, Sebastian Raschka, Sebastian Weiss, +Sergey Kolesnikov, Sergii Khomenko, Shahid, Shivam Kotwalia, Stuart Berg, Sumit Gouthaman, +superzerg, Sven Mayer, tetris, Ti Zhou, Tiago Freitas Pereira, Tian Jin, Tomoaki Oiki, +Vaibhav Sood, vfdev, Vivek Rane, Vladimir Moskva, wangqr, Weber Xie, Will Frey, +Yan Facai (颜发才), yanivbl6, Yaroslav Bulatov, Yixing Lao, Yong Tang, youkaichao, +Yuan (Terry) Tang, Yue Zhang, Yuxin Wu, Ziming Dong, ZxYuan, 黄璞 + +We are also grateful to all who filed issues or helped resolve them, asked and +answered questions, and were part of inspiring discussions. # Release 1.3.0 diff --git a/WORKSPACE b/WORKSPACE index 1bf1069f8801c9d135d77c871520ff733b7713e9..b40913801ba8e3c8ee73f7ba69540b520ad698a6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -5,7 +5,7 @@ http_archive( sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257", strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28 ], ) diff --git a/configure.py b/configure.py index 9ca614f8f9b8f7e30c3e92e84721886d08329d01..425eae676cb679f1ac5d91ba7bd11645d9471923 100644 --- a/configure.py +++ b/configure.py @@ -963,6 +963,19 @@ def set_monolithic(): write_to_bazelrc('build --define framework_shared_object=true') +def create_android_bazelrc_configs(): + # Flags for --config=android + write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool') + write_to_bazelrc( + 'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain') + # Flags for --config=android_arm + write_to_bazelrc('build:android_arm --config=android') + write_to_bazelrc('build:android_arm --cpu=armeabi-v7a') + # Flags for --config=android_arm64 + write_to_bazelrc('build:android_arm64 --config=android') + write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a') + + def main(): # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. @@ -976,6 +989,7 @@ def main(): run_gen_git_source(environ_cp) if is_windows(): + environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -988,9 +1002,11 @@ def main(): set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', - 'with_gcp_support', False, 'gcp') + 'with_gcp_support', True, 'gcp') set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', - 'with_hdfs_support', False, 'hdfs') + 'with_hdfs_support', True, 'hdfs') + set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', + 'with_s3_support', True, 's3') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1030,7 +1046,7 @@ def main(): set_cc_opt_flags(environ_cp) set_mkl() set_monolithic() - + create_android_bazelrc_configs() if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 0c629dabd833804499927e3329d048dbc3dcedad..e8fd66fe61add61c074644c508a74c0f67d31b35 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -120,6 +120,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "ios_x86_64", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_x86_64", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "linux_x86_64", values = {"cpu": "k8"}, @@ -185,6 +194,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_s3_support", + values = {"define": "with_s3_support=true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_xla_support", values = {"define": "with_xla_support=true"}, @@ -276,8 +291,8 @@ config_setting( package_group( name = "internal", packages = [ - "//learning/protonn/llgtm/...", "//tensorflow/...", + "//tensorflow_fold/llgtm/...", ], ) @@ -316,6 +331,7 @@ filegroup( "//tensorflow/compiler/jit/kernels:all_files", "//tensorflow/compiler/jit/legacy_flags:all_files", "//tensorflow/compiler/jit/ops:all_files", + "//tensorflow/compiler/plugin:all_files", "//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla/cc:all_files", @@ -333,6 +349,7 @@ filegroup( "//tensorflow/compiler/xla/service/llvm_ir:all_files", "//tensorflow/compiler/xla/tests:all_files", "//tensorflow/compiler/xla/tools:all_files", + "//tensorflow/compiler/xla/tools/parser:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", @@ -392,6 +409,7 @@ filegroup( "//tensorflow/contrib/linear_optimizer:all_files", "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", + "//tensorflow/contrib/makefile:all_files", "//tensorflow/contrib/meta_graph_transform:all_files", "//tensorflow/contrib/metrics:all_files", "//tensorflow/contrib/mpi_collectives:all_files", @@ -406,7 +424,6 @@ filegroup( "//tensorflow/contrib/remote_fused_graph/pylib:all_files", "//tensorflow/contrib/resampler:all_files", "//tensorflow/contrib/rnn:all_files", - "//tensorflow/contrib/s3:all_files", "//tensorflow/contrib/saved_model:all_files", "//tensorflow/contrib/saved_model/cc/saved_model:all_files", "//tensorflow/contrib/seq2seq:all_files", @@ -428,6 +445,7 @@ filegroup( "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", "//tensorflow/contrib/tensor_forest/proto:all_files", "//tensorflow/contrib/tensorboard:all_files", + "//tensorflow/contrib/tensorboard/db:all_files", "//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/text:all_files", "//tensorflow/contrib/tfprof:all_files", @@ -440,7 +458,6 @@ filegroup( "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", - "//tensorflow/contrib/xla_tf_graph:all_files", "//tensorflow/core:all_files", "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", @@ -455,10 +472,12 @@ filegroup( "//tensorflow/core/kernels/fuzzing:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/kernels/neon:all_files", + "//tensorflow/core/lib/db:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", "//tensorflow/core/platform/hadoop:all_files", + "//tensorflow/core/platform/s3:all_files", "//tensorflow/core/profiler:all_files", "//tensorflow/core/profiler/internal:all_files", "//tensorflow/core/profiler/internal/advisor:all_files", @@ -492,7 +511,10 @@ filegroup( "//tensorflow/python/keras:all_files", "//tensorflow/python/kernel_tests:all_files", "//tensorflow/python/kernel_tests/distributions:all_files", + "//tensorflow/python/kernel_tests/linalg:all_files", + "//tensorflow/python/kernel_tests/random:all_files", "//tensorflow/python/ops/distributions:all_files", + "//tensorflow/python/ops/linalg:all_files", "//tensorflow/python/profiler:all_files", "//tensorflow/python/profiler/internal:all_files", "//tensorflow/python/saved_model:all_files", @@ -501,6 +523,7 @@ filegroup( "//tensorflow/tools/api/golden:all_files", "//tensorflow/tools/api/lib:all_files", "//tensorflow/tools/api/tests:all_files", + "//tensorflow/tools/benchmark:all_files", "//tensorflow/tools/build_info:all_files", "//tensorflow/tools/common:all_files", "//tensorflow/tools/compatibility:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 6919dfe71124ebe861c1649b123e8a714056e45d..ef7eb5a4d16b29aecc34f33cb41dd7cf9450c5f2 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -30,7 +30,10 @@ tf_cuda_library( name = "c_api_internal", srcs = ["c_api.h"], hdrs = ["c_api_internal.h"], - visibility = ["//tensorflow/c:__subpackages__"], + visibility = [ + "//tensorflow:internal", + "//tensorflow/c:__subpackages__", + ], deps = select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 334f867e47800507760eaa71dce91186f646f72d..cd98393e0a5f60c57fc571488a72430ae9cb65cc 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1799,6 +1799,17 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, status->status = MessageToBuffer(def, output_graph_def); } +void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, + TF_Buffer* output_op_def, TF_Status* status) { + const OpDef* op_def; + { + mutex_lock l(graph->mu); + status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); + if (!status->status.ok()) return; + } + status->status = MessageToBuffer(*op_def, output_op_def); +} + TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { return new TF_ImportGraphDefOptions; } @@ -1854,18 +1865,18 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, return; } const int last_node_id = graph->graph.num_node_ids(); - std::vector> return_outputs_vec; - status->status = tensorflow::ImportGraphDef( - opts->opts, def, &graph->graph, &graph->refiner, &return_outputs_vec); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); if (!status->status.ok()) return; for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { auto* node = graph->graph.FindNodeId(i); if (node != nullptr) graph->name_map[node->name()] = node; } - DCHECK_EQ(return_outputs_vec.size(), num_return_outputs); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); for (int i = 0; i < num_return_outputs; ++i) { - return_outputs[i].oper = ToOperation(return_outputs_vec[i].first); - return_outputs[i].index = return_outputs_vec[i].second; + return_outputs[i].oper = ToOperation(results.return_tensors[i].first); + return_outputs[i].index = results.return_tensors[i].second; } } @@ -1945,11 +1956,11 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, } // TOOD(skyewm): change to OutputTensor - std::vector> return_tensors; + tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( - ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors)); + ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); - for (const auto& pair : return_tensors) { + for (const auto& pair : results.return_tensors) { return_nodes->emplace_back(pair.first, pair.second); } return Status::OK(); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index db94828e1a8adb5af7c34500c0675b1a6e93b805..1e8bfdc7b069e306e0e0e0f62a935e9da460cc50 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -118,6 +118,8 @@ typedef enum TF_DataType { TF_HALF = 19, TF_RESOURCE = 20, TF_VARIANT = 21, + TF_UINT32 = 22, + TF_UINT64 = 23, } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding @@ -862,6 +864,13 @@ TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, TF_Status* status); +// Returns the serialized OpDef proto with name `op_name`, or a bad status if no +// such op exists. This can return OpDefs of functions copied into the graph. +TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, + const char* op_name, + TF_Buffer* output_op_def, + TF_Status* status); + // TF_ImportGraphDefOptions holds options that can be passed to // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; @@ -1144,7 +1153,7 @@ TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( const void* proto, size_t proto_len, TF_Status* status); // Sets function attribute named `attr_name` to value stored in `proto`. -// If this attribute is already set to another value, it is overriden. +// If this attribute is already set to another value, it is overridden. // `proto` should point to a sequence of bytes of length `proto_len` // representing a binary serialization of an AttrValue protocol // buffer. diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 4db9a90fdc1c00d5a86de7f5f92f29a3ff4d7df9..d5580b658992413ae6f9cb79ef88751ee28ce465 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1465,5 +1465,26 @@ TEST_F(CApiFunctionTest, AppendHash) { ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name()); } +TEST_F(CApiFunctionTest, GetOpDef) { + DefineFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 1); + EXPECT_EQ(op_def.output_arg_size(), 1); + + TF_DeleteBuffer(buffer); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index c4420290099ee10c89792210dad2604328296515..d220bc5e95f3de22164fa661091d40ce13d8fb11 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -50,6 +51,11 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { +static void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); } TEST(CAPI, Status) { @@ -837,6 +843,31 @@ TEST(CAPI, ShapeInferenceError) { TF_DeleteStatus(status); } +TEST(CAPI, GetOpDef) { + TF_Status* status = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + TF_Buffer* buffer = TF_NewBuffer(); + + TF_GraphGetOpDef(graph, "Add", buffer, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + const OpDef* expected_op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def)); + string expected_serialized; + expected_op_def->SerializeToString(&expected_serialized); + string actual_string(reinterpret_cast(buffer->data), + buffer->length); + EXPECT_EQ(expected_serialized, actual_string); + + TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status)); + ExpectHasSubstr(TF_Message(status), + "Op type not registered 'MyFakeOp' in binary"); + + TF_DeleteBuffer(buffer); + TF_DeleteGraph(graph); + TF_DeleteStatus(status); +} + void StringVectorToArrays(const std::vector& v, std::unique_ptr* ptrs, std::unique_ptr* lens) { diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index e7b9bca5b50e4837534c315b8fa2ca161019d100..b1f7bdaa5420a56386e6983052df20aa976aa867 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/checkpoint_reader.h" #include +#include #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -24,43 +25,43 @@ limitations under the License. #include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; CheckpointReader::CheckpointReader(const string& filename, TF_Status* out_status) - : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) { + : reader_(nullptr), + v2_reader_(nullptr), + var_to_shape_map_(nullptr), + var_to_data_type_map_(nullptr) { // Depending on whether this is a V2 ckpt, initializes "reader_" or // "v2_reader_". std::vector v2_path; if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() && !v2_path.empty()) { - v2_reader_ = - new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */); + v2_reader_.reset( + new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); if (!v2_reader_->status().ok()) { Set_TF_Status_from_Status(out_status, v2_reader_->status()); return; } - var_to_shape_map_ptr_ = BuildV2VarToShapeMap(); + auto result = BuildV2VarMaps(); + var_to_shape_map_.swap(result.first); + var_to_data_type_map_.swap(result.second); } else { - reader_ = new TensorSliceReader(filename); + reader_.reset(new TensorSliceReader(filename)); if (!reader_->status().ok()) { Set_TF_Status_from_Status(out_status, reader_->status()); return; } - var_to_shape_map_ptr_ = - new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()); + var_to_shape_map_.reset( + new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap())); + var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap( + reader_->GetVariableToDataTypeMap())); } } -CheckpointReader::~CheckpointReader() { - delete var_to_shape_map_ptr_; - delete reader_; - delete v2_reader_; -} - bool CheckpointReader::HasTensor(const string& name) const { if (reader_ != nullptr) { return reader_->HasTensor(name, nullptr, nullptr); @@ -70,8 +71,14 @@ bool CheckpointReader::HasTensor(const string& name) const { const TensorSliceReader::VarToShapeMap& CheckpointReader::GetVariableToShapeMap() const { - CHECK(var_to_shape_map_ptr_); - return *var_to_shape_map_ptr_; + CHECK(var_to_shape_map_); + return *var_to_shape_map_; +} + +const TensorSliceReader::VarToDataTypeMap& +CheckpointReader::GetVariableToDataTypeMap() const { + CHECK(var_to_data_type_map_); + return *var_to_data_type_map_; } const string CheckpointReader::DebugString() const { @@ -100,7 +107,9 @@ void CheckpointReader::GetTensor( } } -TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { +std::pair, + std::unique_ptr> +CheckpointReader::BuildV2VarMaps() { CHECK(v2_reader_ != nullptr); CHECK(v2_reader_->status().ok()); @@ -123,18 +132,23 @@ TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { } // Second pass: adds the entries, ignoring the filtered keys. - TensorSliceReader::VarToShapeMap* var_to_shape_map = - new TensorSliceReader::VarToShapeMap; + std::unique_ptr var_to_shape_map( + new TensorSliceReader::VarToShapeMap); + std::unique_ptr var_to_data_type_map( + new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - (*var_to_shape_map)[v2_reader_->key().ToString()] = - TensorShape(entry.shape()); + string key = v2_reader_->key().ToString(); + (*var_to_shape_map)[key] = TensorShape(entry.shape()); + (*var_to_data_type_map)[key] = DataType(entry.dtype()); } - return var_to_shape_map; // Owned by caller. + // The returned pointers are owned by the caller. + return std::make_pair(std::move(var_to_shape_map), + std::move(var_to_data_type_map)); } } // namespace checkpoint diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 1124416380df624f97b3ce2ebaadb04b3c17d341..4de1300a7f66a8b4eb8074819432fd7dd597bb15 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_C_CHECKPOINT_READER_H #define TENSORFLOW_C_CHECKPOINT_READER_H +#include +#include + #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" @@ -24,7 +27,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_slice_reader.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; @@ -38,15 +40,18 @@ class TensorSliceReader; class CheckpointReader { public: CheckpointReader(const string& filepattern, TF_Status* out_status); - ~CheckpointReader(); bool HasTensor(const string& name) const; const string DebugString() const; - // Returns a map from variable names to its shape. Slices of a partitioned + // Returns a map from variable names to their shapes. Slices of a partitioned // tensor are combined into a single entry. const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const; + // Returns a map from variable names to their data types. Slices of a + // partitioned tensor are combined into a single entry. + const TensorSliceReader::VarToDataTypeMap& GetVariableToDataTypeMap() const; + // Attempts to look up the tensor named "name" and stores the found result in // "out_tensor". void GetTensor(const string& name, @@ -54,14 +59,19 @@ class CheckpointReader { TF_Status* out_status) const; private: - // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller. + // Uses "v2_reader_" to build "var name -> shape" and "var name -> data type" + // maps; both owned by caller. // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()". - TensorSliceReader::VarToShapeMap* BuildV2VarToShapeMap(); + std::pair, + std::unique_ptr > + BuildV2VarMaps(); + + // Invariant: exactly one of "reader_" and "v2_reader_" is non-null. + std::unique_ptr reader_; + std::unique_ptr v2_reader_; - // Invariant: exactly one of "reader_" and "v2_reader_" is non-nullptr. - TensorSliceReader* reader_; // Owned. - BundleReader* v2_reader_; // Owned. - TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned. + std::unique_ptr var_to_shape_map_; + std::unique_ptr var_to_data_type_map_; TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader); }; diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 52945d32391ddcd9bddb7726ddac68ee1ba9ae58..c77896b80b478cd34d3502e1061a7e76204ba021 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "tf_cuda_cc_test", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -10,13 +11,15 @@ load( tf_cuda_library( name = "c_api", - srcs = ["c_api.cc"], + srcs = [ + "c_api.cc", + "c_api_internal.h", + ], hdrs = ["c_api.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ @@ -33,7 +36,22 @@ tf_cuda_library( }), ) -tf_cc_test( +tf_cuda_library( + name = "c_api_internal", + hdrs = ["c_api_internal.h"], + deps = [ + ":c_api", + ":runtime", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_internal", + ], +) + +tf_cuda_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], deps = [ @@ -53,7 +71,6 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ @@ -85,3 +102,14 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "tape", + srcs = ["tape.cc"], + hdrs = ["tape.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 801d7307494e6585fbb7ee0fa4e6724ebe2c6f94..8359de62b7ff690fec9f6a0e3280f947c62f8b6e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -51,69 +52,25 @@ string DeviceName(tensorflow::Device* d) { } } // namespace -struct TFE_Context { - explicit TFE_Context(TF_Session* s) : session(s) {} - - // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* session; - tensorflow::Rendezvous* rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - std::unique_ptr pflr; +extern "C" { - std::unordered_map - kernel_cache; +TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { - return pflr->GetFLR(d->name()); - } +void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, + size_t proto_len, TF_Status* status) { + TF_SetConfig(&options->session_options, proto, proto_len, status); +} - const std::vector& devices() { return session->devices; } -}; - -struct TFE_TensorHandle { - TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) - : t(t), d(d) {} - - tensorflow::Tensor t; - // TODO(ashankar): d == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('d' should always be a - // valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* d; -}; - -struct TFE_Op { - TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} - - bool const is_function() const { return attr_types == nullptr; } - - TFE_Context* ctx; // Must outlive the TFE_Op. - const string name; - tensorflow::AttrBuilder attrs; - const tensorflow::AttrTypeMap* attr_types; - std::vector inputs; - std::vector input_devices; - tensorflow::Device* device; -}; +void TFE_ContextOptionsSetDevicePlacementPolicy( + TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { + options->policy = policy; +} -extern "C" { +void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } -TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { +TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { TF_Graph* graph = TF_NewGraph(); - TF_Session* session = TF_NewSession(graph, opts, status); + TF_Session* session = TF_NewSession(graph, &opts->session_options, status); if (status->status.ok()) { if (session->device_mgr == nullptr || session->devices.empty()) { status->status = tensorflow::errors::InvalidArgument( @@ -128,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { } TFE_Context* ret = new TFE_Context(session); + ret->policy = opts->policy; ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime( - ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION, - &ret->func_lib_def, {})); + ret->session->device_mgr, opts->session_options.options.env, + TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {})); ret->rendezvous = new tensorflow::IntraProcessRendezvous(ret->session->device_mgr); @@ -330,6 +288,20 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return ret; } +TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, + const char* op_or_function_name, + const char* attr_name, unsigned char* is_list, + TF_Status* status) { + TF_AttrType ret; + TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); + if (!status->status.ok()) { + return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. + } + ret = TFE_OpGetAttrType(op, attr_name, is_list, status); + TFE_DeleteOp(op); + return ret; +} + void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { op->attrs.Set(attr_name, value); } @@ -451,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, namespace { tensorflow::Status ValidateInputTypeAndPlacement( - tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel) { + TFE_Context* ctx, tensorflow::Device* host_device, + tensorflow::Device* op_device, TFE_Op* op, + const tensorflow::OpKernel* kernel, + std::vector* copied_tensors) { const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); if (memtypes.size() != op->inputs.size()) { return tensorflow::errors::InvalidArgument( @@ -464,11 +438,50 @@ tensorflow::Status ValidateInputTypeAndPlacement( const tensorflow::Device* actual_device = op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; if (expected_device != actual_device) { - return tensorflow::errors::InvalidArgument( - "cannot compute ", op->name, " as input #", i, - " was expected to be on ", expected_device->name(), - " but is actually on ", actual_device->name(), - " (operation running on ", op_device->name(), ")"); + switch (ctx->policy) { + case TFE_DEVICE_PLACEMENT_EXPLICIT: + // TODO(xpan): See if we could bubble python related error up + // to python level. + return tensorflow::errors::InvalidArgument( + "Tensors on conflicting devices:" + " cannot compute ", + op->name, " as input #", i, " was expected to be on ", + expected_device->name(), " but is actually on ", + actual_device->name(), " (operation running on ", + op_device->name(), ")", + " Tensors can be copied explicitly using .gpu() or .cpu()," + " or transparently copied by using tfe.enable_eager_execution(" + "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices" + " may slow down your model"); + case TFE_DEVICE_PLACEMENT_WARN: + LOG(WARNING) << "before computing " << op->name << " input #" << i + << " was expected to be on " << expected_device->name() + << " but is actually on " << actual_device->name() + << " (operation running on " << op_device->name() + << "). This triggers a copy which can be a performance " + "bottleneck."; + break; + case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing. + break; + } + // We are only here if the policy is warn or silent copies, so we should + // trigger a copy. + TFE_TensorHandle original{op->inputs[i], op->input_devices[i]}; + TF_Status* s = TF_NewStatus(); + TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( + &original, ctx, expected_device->name().c_str(), s); + if (!s->status.ok()) { + tensorflow::Status status = s->status; + delete s; + return tensorflow::errors::Internal( + "Failed copying input tensor from ", actual_device->name(), " to ", + expected_device->name(), " in order to run ", op->name, ": ", + status.error_message()); + } + op->inputs[i] = copied_tensor->t; + copied_tensors->push_back(copied_tensor); + op->input_devices[i] = copied_tensor->d; + delete s; } if (op->inputs[i].dtype() != kernel->input_type(i)) { return tensorflow::errors::InvalidArgument( @@ -511,10 +524,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } - status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op, - kernel->kernel()); + std::vector copied_tensors; + status->status = ValidateInputTypeAndPlacement( + ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors); output_memory_types = &kernel->kernel()->output_memory_types(); if (!status->status.ok()) { + for (auto* t : copied_tensors) { + TFE_DeleteTensorHandle(t); + } return; } // WARNING: kernel->Run utilizes the FunctionLibraryRuntime @@ -526,6 +543,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // sense for FunctionLibraryRuntime to ensure thread-safe access to // FunctionLibraryDefinition?). status->status = kernel->Run(&op->inputs, &outputs); + for (auto* t : copied_tensors) { + TFE_DeleteTensorHandle(t); + } if (!status->status.ok()) return; *num_retvals = std::min(*num_retvals, outputs.size()); for (int i = 0; i < *num_retvals; ++i) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a4f7d308fbb4008d00bd97abf40c9ead5fdb1986..865580c5f3a823d9cf49fe460bd007e3b3b88767 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -43,14 +43,46 @@ limitations under the License. extern "C" { #endif +typedef struct TFE_ContextOptions TFE_ContextOptions; + +// Return a new options object. +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); + +// Set the config in TF_ContextOptions.options. +// config should be a serialized tensorflow.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + +// Controls how to act when we try to run an operation on a given device but +// some input tensors are not on that device. +typedef enum TFE_ContextDevicePlacementPolicy { + // The default: running operations with input tensors on the wrong device will + // fail. + TFE_DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + TFE_DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the + // operation will be blocked till the copy completes. + TFE_DEVICE_PLACEMENT_SILENT = 2, +} TFE_ContextDevicePlacementPolicy; + +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( + TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); + +// Destroy an options object. +TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); + // "Context" under which operations/functions are executed. It encapsulates // things like the available devices, resource manager etc. // // TODO(ashankar): Merge with TF_Session? typedef struct TFE_Context TFE_Context; -TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( + const TFE_ContextOptions* opts, TF_Status* status); TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); @@ -107,6 +139,12 @@ TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_St TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status); +// Get an attribute type given an op name; a fusion of TFE_NewOp and +// TFE_OpGetAttrType for use from Python without the overhead of the individual +// calls and memory management of TFE_Op. +TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( + TFE_Context* ctx, const char* op_or_function_name, const char* attr_name, + unsigned char* is_list, TF_Status* status); TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..0971e2ab2fe98cc8bf6f631f41d5adce90ee7051 --- /dev/null +++ b/tensorflow/c/eager/c_api_internal.h @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/runtime.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +struct TFE_ContextOptions { + TF_SessionOptions session_options; + TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT}; +}; + +struct TFE_Context { + explicit TFE_Context(TF_Session* s) : session(s) {} + + TFE_ContextDevicePlacementPolicy policy; + + // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. + TF_Session* session; + tensorflow::Rendezvous* rendezvous; + + tensorflow::mutex functions_mu; + tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ + tensorflow::OpRegistry::Global(), {}}; + + // One FunctionLibraryRuntime per device. + // func_libs[i] is the FunctionLibraryRuntime corresponding to + // session->devices[i]. + std::unique_ptr pflr; + + std::unordered_map + kernel_cache; + + tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { + return pflr->GetFLR(d->name()); + } + + const std::vector& devices() { return session->devices; } +}; + +struct TFE_TensorHandle { + TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) + : t(t), d(d) {} + + tensorflow::Tensor t; + // TODO(ashankar): d == nullptr iff local CPU + // This was expedient, but perhaps worth revisiting ('d' should always be a + // valid pointer?) + // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are + // provided with the appropriate TFE_Context. + // + // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a + // TFE_TensorHandle does not outlive the TFE_Context from which it came? + tensorflow::Device* d; +}; + +struct TFE_Op { + TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) + : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + + bool const is_function() const { return attr_types == nullptr; } + + TFE_Context* ctx; // Must outlive the TFE_Op. + const tensorflow::string name; + tensorflow::AttrBuilder attrs; + const tensorflow::AttrTypeMap* attr_types; + std::vector inputs; + std::vector input_devices; + tensorflow::Device* device; +}; + +#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 72e0fe8a1565a9a717c01aed83044cab2dd2dfbc..4af91b8853d0e85570bad136752a9d0a04b87da5 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); tensorflow::testing::StartTiming(); @@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp); void BM_Execute(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -109,9 +109,9 @@ BENCHMARK(BM_Execute); TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) { TEST(CAPI, TensorHandleCopyBetweenDevices) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status.get()); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); @@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopy) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + // Disable the test if no GPU is present. + if (num_devices > 1) { + const int device_to_use = 1; + const string name(TF_DeviceListName(devices, device_to_use, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* hgpu = + TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); + TFE_OpSetDevice(matmul, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + TEST(CAPI, Execute) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -285,10 +331,10 @@ string MatMulFunction() { TEST(CAPI, FunctionDefAndExecute) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); string function_def = MatMulFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) { void BM_ExecuteFunction(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); string function_def = MatMulFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -406,10 +452,10 @@ TEST(CAPI, Variables) { // Variables use resource handles, so this is really a test for resource // tensor handling. TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -446,10 +492,10 @@ TEST(CAPI, Variables) { void BM_ReadVariable(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc new file mode 100644 index 0000000000000000000000000000000000000000..464612a81ebda428f5582b6927f3a3b00a5aa6f5 --- /dev/null +++ b/tensorflow/c/eager/tape.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/tape.h" + +namespace tensorflow { +namespace eager { + +bool GradientTape::ShouldRecord(gtl::ArraySlice tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; + } + } + return false; +} + +void GradientTape::Watch(int64 tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +void GradientTape::RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, void* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id)) { + backward_function_deleter(); + return; + } + std::vector ids; + ids.reserve(input_tensor_id.size()); + for (int64 i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64 op_id = next_op_id_++; + std::vector tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. + tensor_tape_[o.id] = op_id; + tensor_usage_[o.id] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, + backward_function_deleter}; +} + +void GradientTape::DeleteTrace(int64 tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64 op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64 id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(); + op_tape_.erase(op_it); +} + +std::pair GradientTape::Export() { + return {std::move(tensor_tape_), std::move(op_tape_)}; +} + +} // namespace eager +} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h new file mode 100644 index 0000000000000000000000000000000000000000..df51f300eb61d54cb1e06d5a58a9b10e834f73c4 --- /dev/null +++ b/tensorflow/c/eager/tape.h @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TAPE_H_ +#define TENSORFLOW_C_EAGER_TAPE_H_ + +// Language-agnostic gradient tape. Does not perform backpropagation, just +// maintains the data structures required to do so. + +#include +#include +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace eager { + +// Information about a tensor. +struct TapeTensor { + int64 id; // Expected to be unique in the lifetime of this process. + DataType dtype; + TensorShape shape; +}; + +// Represents an entry in the tape. +struct OpTapeEntry { + string op_type; + std::vector output_tensor_info; + std::vector input_tensor_id; + + // TODO(apassos) consider narrowing down this interface. + void* backward_function; + + // Should be called before deleting the backward function. TODO(apassos) use + // unique_ptrs to ensure this happens. + std::function backward_function_deleter; +}; + +// Map from tensor_id to internally-defined operation-id of the operation which +// produced this tensor. A value of -1 means that the tensor was directly +// watched and not the result of any operation in the tape. +using TensorTape = std::unordered_map; + +// Map from operation-id to tape entry. +using OpTape = std::unordered_map; + +// Traces the execution of operations, doing eager garbage collection, and +// exporting a full trace so other code can do backpropagation. Not thread-safe. +class GradientTape { + public: + GradientTape() {} + + bool ShouldRecord(gtl::ArraySlice tensor_ids); + + void Watch(int64 tensor_id); + + void RecordOperation(const string& op_type, + gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + void* backward_function, + const std::function& backward_function_deleter); + + void DeleteTrace(int64 tensor_id); + + // Note: it is only valid to call Export once per tape, and after calling + // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch, + // Record, and Delete have undefined behavior). + std::pair Export(); + + private: + TensorTape tensor_tape_; + OpTape op_tape_; + int64 next_op_id_{0}; + + // Map from tensor id to number of remaining usages (i.e. how many entries in + // the tape refer to it); to aid in tape garbage collection. + std::unordered_map tensor_usage_; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TAPE_H_ diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc index 2423d83dda93938aa1a2ba0ed0ed7356bd65d39f..d2d887f32c44af5980b50785f282187d0f6fcff4 100644 --- a/tensorflow/c/while_loop_test.cc +++ b/tensorflow/c/while_loop_test.cc @@ -318,7 +318,7 @@ TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) { @@ -358,7 +358,7 @@ TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } // TODO(skyewm): enable this when it works (currently segfaults!) @@ -389,7 +389,7 @@ TEST_F(CApiWhileLoopTest, WrongGraph) { params_->body_outputs[0] = inputs_[0]; // TODO(skyewm): improve error message ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, BadTypes) { diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index ac288b1d834d267f5bab887f45de8173e31f88ea..d7446b9560fd7dc8377ea3710641906b274313a9 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define _USE_MATH_DEFINES +#include + #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/math_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -200,8 +203,8 @@ Status TanhGrad(const Scope& scope, const Operation& op, // evaluated. Scope grad_scope = scope.WithControlDependencies(grad); auto y = ConjugateHelper(grad_scope, op.output(0)); - grad_outputs->push_back(internal::TanhGrad(scope, y, grad)); - return scope.status(); + grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad)); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Tanh", TanhGrad); @@ -256,8 +259,8 @@ Status SigmoidGrad(const Scope& scope, const Operation& op, // evaluated. Scope grad_scope = scope.WithControlDependencies(grad); auto y = ConjugateHelper(grad_scope, op.output(0)); - grad_outputs->push_back(internal::SigmoidGrad(scope, y, grad)); - return scope.status(); + grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad)); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad); @@ -484,7 +487,7 @@ Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op, auto grad = grad_inputs[0]; auto zeros = ZerosLike(scope, grad); auto gx_1 = Where3(scope, comparator, grad, zeros); - auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros); + auto gx_2 = Where3(scope, comparator, zeros, grad); return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); } @@ -696,15 +699,32 @@ Status MeanGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Mean", MeanGrad); +Status ErfGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto grad = grad_inputs[0]; + auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), + grad.type()); + Scope grad_scope = scope.WithControlDependencies(grad); + auto x = ConjugateHelper(grad_scope, op.input(0)); + // grad * 2/sqrt(pi) * exp(-x**2) + auto dx = Mul(grad_scope, + Mul(grad_scope, grad, two_over_root_pi), + Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x)))); + grad_outputs->push_back(dx); + return grad_scope.status(); +} +REGISTER_GRADIENT_OP("Erf", ErfGrad); + Status LgammaGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { auto grad = grad_inputs[0]; Scope grad_scope = scope.WithControlDependencies(grad); auto x = ConjugateHelper(grad_scope, op.input(0)); - auto dx = Mul(scope, grad, Digamma(scope, x)); + auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x)); grad_outputs->push_back(dx); - return scope.status(); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Lgamma", LgammaGrad); diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index a174f223ad59b0a111b3d13cb59fb2b13a0095b0..6313f41da5e5f9cf88be4c8a84408a8df77f0e25 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -64,7 +64,9 @@ class CWiseUnaryGradTest : public ::testing::Test { IMAG, CONJ, COMPLEX, - ANGLE + ANGLE, + LGAMMA, + ERF }; template @@ -168,6 +170,12 @@ class CWiseUnaryGradTest : public ::testing::Test { case ANGLE: y = Angle(scope_, x); break; + case LGAMMA: + y = Lgamma(scope_, x); + break; + case ERF: + y = Erf(scope_, x); + break; } float max_error; @@ -503,6 +511,42 @@ TEST_F(CWiseUnaryGradTest, Angle) { TestCWiseGrad(ANGLE, x_fn); } +TEST_F(CWiseUnaryGradTest, Lgamma) { + auto x_fn = [this](const int i) { + return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5}); + }; + TestCWiseGrad(LGAMMA, x_fn); +} + +TEST_F(CWiseUnaryGradTest, Lgamma_Complex) { + auto x_fn = [this](const int i) { + return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}}); + }; + // TODO(kbsriram) + // Add test when the lgamma kernel supports complex numbers + if (false) { + TestCWiseGrad(LGAMMA, x_fn); + } +} + +TEST_F(CWiseUnaryGradTest, Erf) { + auto x_fn = [this](const int i) { + return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3}); + }; + TestCWiseGrad(ERF, x_fn); +} + +TEST_F(CWiseUnaryGradTest, Erf_Complex) { + auto x_fn = [this](const int i) { + return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}}); + }; + // TODO(kbsriram) + // Add test when the erf kernel supports complex numbers + if (false) { + TestCWiseGrad(ERF, x_fn); + } +} + class MathGradTest : public ::testing::Test { protected: MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} @@ -821,17 +865,5 @@ TEST_F(NaryGradTest, Minimum) { RunTest(x, x_init_value, y, shape); } -TEST_F(NaryGradTest, Lgamma) { - TensorShape shape({3, 2}); - auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto y = Lgamma(scope_, x); - // Select values to avoid instability when computing finite differences. - // Ref: https://en.wikipedia.org/wiki/File:Gamma_plot.svg - Tensor x_init_value = - test::AsTensor({-3.5f, -2.5f, -1.5f, 1.0f, 2.0f, 3.5f}, {3, 2}); - RunTest(x, x_init_value, y, shape); - // TODO(suharshs): add test case for complex values -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 1cc7cf3f2021ede8269368aa46007b5ceaace606..d29ad3ebcbe29087d5572b51c7713e0c98d0d840 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -56,6 +56,7 @@ cc_library( ":constants", ] + if_not_mobile([ "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index eac8da0ab1b05e7d5cc8d27a1e1ffecc85515cdb..2b8cc6024cb85e4f6269313927ff66d1d9a1cf79 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -97,11 +97,15 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, &computation, &compile_result->has_context_arg)); - if (!flags.debug_dir.empty()) { + if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); - string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb"); - TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module)); + // Serialize the SessionModule deterministically so that all the outputs of + // a tf_library genrule are deterministic. + string proto; + TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); + TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), flags.out_session_module, proto)); } xla::cpu::CpuAotCompilationOptions aot_opts( flags.target_triple, flags.target_cpu, flags.target_features, diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index 5aff10346fa368f214436d1d0837505ffbbc771e..7c2f27e550d44c2487f91acf1029c962ac3f5d01 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -33,9 +33,6 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "fetch nodes will be dumped to stdout in a comma-separated list. " "Typically used to format arguments for other tools, e.g. " "freeze_graph."}, - {"debug_dir", &flags->debug_dir, - "Specifies a directory to dump debugging information, including " - "rewritten graphs and the XLA HLO module."}, // Flags controlling the XLA ahead-of-time compilation, that correspond to // the fields of xla::cpu::CpuAotCompilationOptions. // @@ -64,6 +61,8 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "namespaces are given, within the global namespace."}, {"out_object", &flags->out_object, "Output object file name."}, {"out_header", &flags->out_header, "Output header file name."}, + {"out_session_module", &flags->out_session_module, + "Output session module proto."}, {"gen_name_to_index", &flags->gen_name_to_index, "Generate name-to-index data for Lookup{Arg,Result}Index methods."}, {"gen_program_shape", &flags->gen_program_shape, diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index 3246dbf95c8a60130af91bc3891b15829aa5e638..3519659e3af7cd345f30080a07ce91fb858623fb 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -29,7 +29,6 @@ struct MainFlags { string graph; string config; bool dump_fetch_nodes = false; - string debug_dir; string target_triple; string target_cpu; string target_features; @@ -37,6 +36,7 @@ struct MainFlags { string cpp_class; string out_object; string out_header; + string out_session_module; // C++ codegen options bool gen_name_to_index = false; diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index cfde5651c679f09e22a063f181ccc2bcbd7e5653..6b037f276ad1d6771b904bb970f45f32ae9531b8 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -180,33 +180,6 @@ TEST(TFCompileTest, Gather) { } EXPECT_EQ(gather_const.result0_data(), gather.results()[0]); } - - // Bad indices returns an error. - { - const float params[4] = {1, 2, 3, 4}; - std::copy(params + 0, params + 4, gather.arg0_data()); - const int32 indices[2] = {1, 4}; - std::copy(indices + 0, indices + 2, gather.arg1_data()); - EXPECT_FALSE(gather.Run()); - EXPECT_EQ(gather.error_msg(), "Invalid index for gather"); - } - - // Try a successful gather again, after the error, to ensure the error state - // is cleared. - { - const float params[4] = {1, 2, 3, 4}; - std::copy(params + 0, params + 4, gather.arg0_data()); - const int32 indices[2] = {1, 3}; - std::copy(indices + 0, indices + 2, gather.arg1_data()); - EXPECT_TRUE(gather.Run()); - EXPECT_EQ(gather.error_msg(), ""); - const float results[2] = {2, 4}; - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(gather.result0(i), results[i]); - EXPECT_EQ(gather.result0_data()[i], results[i]); - } - EXPECT_EQ(gather.result0_data(), gather.results()[0]); - } } TEST(TFCompileTest, MatMul2) { diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 461a9315c58a33bcb96d29090036a02206385580..2adb1dc65ed75a734e517141ecbb7a0ef2323ee4 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -165,6 +165,34 @@ def tf_library(name, graph, config, tags=tags, ) + # Rule that runs tfcompile to produce the SessionModule proto, useful for + # debugging. TODO(b/64813587): Once the SessionModule proto is + # deterministic, move this into the main rule above. + session_module_pb = name + "_session_module.pb" + native.genrule( + name=(name + "_session_module"), + srcs=[ + tfcompile_graph, + config, + ], + outs=[ + session_module_pb, + ], + cmd=("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + (tfcompile_flags or "")), + tools=[tfcompile_tool], + visibility=visibility, + testonly=testonly, + local=1, + tags=tags, + ) + # The cc_library rule packaging up the header and object file, and needed # kernel implementations. need_xla_data_proto = (tfcompile_flags and @@ -186,8 +214,6 @@ def tf_library(name, graph, config, "//tensorflow/compiler/xla:xla_data_proto", ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32", - "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", @@ -295,7 +321,6 @@ def tf_library(name, graph, config, tags=tags, ) - def target_llvm_triple(): """Returns the target LLVM triple to be used for compiling the target.""" # TODO(toddw): Add target_triple for other targets. For details see: diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bf63b7e5016c33b47a8839e03894930203ab0654..bf7d9cf14d10f41aa48ea594a8d63db97b9973e1 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -33,6 +33,7 @@ cc_library( deps = [ ":xla_cpu_device", ":xla_cpu_jit", + "//tensorflow/compiler/plugin", ] + if_cuda_is_configured([ ":xla_gpu_device", ":xla_gpu_jit", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index ded01ff6530745dafc22a259576e1f6816d20c02..27c5da08c112664d361b5f969d100eed7b9df65c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -110,7 +110,7 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = - reinterpret_cast(const_cast(t->tensor_data().data())); + reinterpret_cast(const_cast(t->tensor_data().data())); TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index db2ed16f95e71f5055d0cb2fce3cda8a7f2976d9..78d0aa86a8fae9a0c6035bdc579ef800337df917 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -560,6 +560,7 @@ Status MarkForCompilationPass::RunImpl( name = strings::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; } } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 579ce415c5c3c4951be1596a37d47b7930bcf4fb..b3d258aea177fbefa4bae51d8156da2ff86c9032 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -144,8 +144,8 @@ TEST(XlaCompilationTest, UnsupportedTypes) { Node* a = ops::SourceOp( "Const", builder.opts() .WithName("A") - .WithAttr("dtype", DT_COMPLEX64) - .WithAttr("value", Tensor(DT_COMPLEX64, TensorShape()))); + .WithAttr("dtype", DT_COMPLEX128) + .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); TF_EXPECT_OK(builder.ToGraph(graph.get())); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 57b9d6b56bca23e94dc172dce2412ed151643318..e238252751e677eb947f6df03e3b2f2e948ffe19 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -39,9 +39,9 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, - DEVICE_CPU_XLA_JIT, options, name_prefix, - &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create( + "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, + /*register_device_for_compilation=*/true, &device)); devices->push_back(device.release()); return Status::OK(); } @@ -50,8 +50,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 888461611fee6dd78d086ca9da67da40d515bca1..d4d8fe1c1d575b4e35d624621cc709e3a16569d5 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" @@ -107,18 +108,21 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, std::unique_ptr* device) { + const string& name_prefix, bool register_device_for_compilation, + std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = jit_device_name; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; - registration.compile_resource_ops = true; - XlaOpRegistry::RegisterCompilationDevice(device_name, registration); + if (register_device_for_compilation) { + // These are no-ops if they have already been done previously for + // this device_name/compilation_device_name pair. + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = jit_device_name; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); + } auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { @@ -158,7 +162,8 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, const Metadata** metadata) { - XlaDevice* xla_device = dynamic_cast(ctx->device()); + XlaDevice* xla_device = + dynamic_cast(ctx->device()->UnderlyingDevice()); if (xla_device == nullptr) { return errors::Internal( "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(), @@ -236,7 +241,8 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { // When TraceMe profiling is off (which is the default), the // following TraceMe constructor is simply a conditional test of // false value. Measurements show that its overhead is negligible. - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); + port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->Compute(context); } @@ -244,7 +250,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); + port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->ComputeAsync(context, done); } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 0d90b8b692896d8addf5ffead3980a5bf640c85c..d2ec38293c429f04f088bf3726ba97eb4e4b0dba 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -74,6 +74,7 @@ class XlaDevice : public LocalDevice { static Status Create(const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, + bool register_device_for_compilation, std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 4474d8f4eb06afa78ea36332a8cc58f9d240c1b0..2326070358d67c0cf30ef17fab5c93862cd8932c 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -39,9 +39,9 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - Status status = - XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, - name_prefix, &device); + Status status = XlaDevice::Create( + "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, + /*register_device_for_compilation=*/true, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; @@ -55,8 +55,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4e4cbe200a21b6584f0fefb4cf43874fb213e244..2614deefd8823dcb8f38e9e22ae4e78145d0d96a 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -42,9 +42,9 @@ Status XlaInterpreterDeviceFactory::CreateDevices( (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, - DEVICE_INTERPRETER_XLA_JIT, options, - name_prefix, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create( + "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, + options, name_prefix, /*register_device_for_compilation=*/true, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c1edf2448c54ffddd7b70dcdfb1609080ca81b65 --- /dev/null +++ b/tensorflow/compiler/plugin/BUILD @@ -0,0 +1,56 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Configuration file for an XLA plugin. + + please don't check in changes to this file. to prevent changes appearing + in git status, use: + + git update-index --assume-unchanged tensorflow/compiler/plugin/BUILD + + To add additional devices to the XLA subsystem, add targets to the + dependency list in the 'plugin' target. For instance: + + deps = ["//tensorflow/compiler/plugin/example:plugin_lib"], + + ** Please don't remove this file - it is supporting some 3rd party plugins ** +""" + +licenses(["notice"]) + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "plugin", + deps = [ + #"//tensorflow/compiler/plugin/example:example_lib", + ], +) + +#----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/plugin/README.md b/tensorflow/compiler/plugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9dd0d2bdab5e2c990fd547cef4b657253c545715 --- /dev/null +++ b/tensorflow/compiler/plugin/README.md @@ -0,0 +1,16 @@ +3rd party XLA devices +--------------------- + +This directory is intended as a place for 3rd party XLA devices which are _not_ +integrated into the public repository. + +By adding entries to the BUILD target in this directory, a third party device +can be included as a dependency of the JIT subsystem. + +For integration into the unit test system, see the files: + +- tensorflow/compiler/tests/plugin.bzl +- tensorflow/compiler/xla/tests/plugin.bzl + + +- diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c8269b3d5b7c63c4ee82cdfa9244a3cd6308988d..0ff99c5156ded2ae05c6976e3da8f31fce32f8f2 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -23,6 +23,10 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) generate_backend_suites() @@ -97,9 +101,13 @@ tf_xla_py_test( size = "small", srcs = ["binary_ops_test.py"], shard_count = 5, + tags = [ + "optonly", # Times out frequently in fastbuild mode. + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", + "//tensorflow/python:bitwise_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", @@ -179,6 +187,7 @@ tf_xla_py_test( "noasan", "nomsan", "notsan", + "optonly", # Times out frequently in fastbuild mode. ], deps = [ ":xla_test", @@ -208,11 +217,6 @@ tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], - # TODO(b/62962492): Test fails with assertion error. - tags = [ - "manual", - "notap", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -460,7 +464,7 @@ tf_xla_py_test( tf_xla_py_test( name = "unary_ops_test", - size = "small", + size = "medium", srcs = ["unary_ops_test.py"], deps = [ ":xla_test", @@ -509,12 +513,8 @@ tf_xla_py_test( tf_xla_py_test( name = "gather_test", - size = "small", + size = "medium", srcs = ["gather_test.py"], - # Gather needs CustomCall on CPU, which is not available in normal - # (not precompiled) TensorFlow. The flag below excludes the CPU - # backend. - disabled_backends = "cpu", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -585,11 +585,12 @@ cc_library( tf_cuda_cc_test( name = "randomized_tests", + size = "large", # This test is randomized, so only run it if explicitly requested. tags = [ "manual", "notap", - ], + ] + tf_cuda_tests_tags(), deps = [":randomized_tests_library"], ) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index c2ce121348da034efe002dd8db0f5b0703324a41..ec547e16cd9c91a1e25bc963b9a3cafddf7326cd 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -46,7 +46,9 @@ class ArgMinMaxTest(xla_test.XLATestCase): self.assertAllEqual(result, expected) def testArgMinMax(self): - for dtype in self.numeric_types: + # Complex numbers do not support argmin/argmax. + minmax_types = set(self.numeric_types) - set(self.complex_types) + for dtype in minmax_types: self._assertOpOutputMatchesExpected( lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), np.array([1, 10, 27, 3, 3, 4], dtype=dtype), diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 792c01327c2cb4f159af7b60baf3a5d1b124acb6..d412c572ae16b84c2434819aa0a2d881defef5f9 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -45,6 +46,10 @@ class BinaryOpsTest(XLATestCase): equality_test = self.assertAllClose equality_test(result, expected, rtol=1e-3) + def _testSymmetricBinary(self, op, a, b, expected, equality_test=None): + self._testBinary(op, a, b, expected, equality_test) + self._testBinary(op, b, a, expected, equality_test) + def ListsAreClose(self, result, expected, rtol): """Tests closeness of two lists of floats.""" self.assertEqual(len(result), len(expected)) @@ -89,6 +94,15 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._testBinary( + math_ops.atan2, + np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), + np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), + expected=np.array( + [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) + self._testBinary( gen_math_ops._reciprocal_grad, np.array([4, -3, -2, 1], dtype=dtype), @@ -193,6 +207,32 @@ class BinaryOpsTest(XLATestCase): np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_and, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b0, 0b101, 0b1000], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_or, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) + + lhs = np.array([0, 5, 3, 14], dtype=dtype) + rhs = np.array([5, 0, 7, 11], dtype=dtype) + self._testBinary( + bitwise_ops.left_shift, lhs, rhs, + expected=np.left_shift(lhs, rhs)) + self._testBinary( + bitwise_ops.right_shift, lhs, rhs, + expected=np.right_shift(lhs, rhs)) + + if dtype in [np.int8, np.int16, np.int32, np.int64]: + lhs = np.array([-1, -5, -3, -14], dtype=dtype) + rhs = np.array([5, 0, 1, 11], dtype=dtype) + self._testBinary( + bitwise_ops.right_shift, lhs, rhs, + expected=np.right_shift(lhs, rhs)) def testNumericOps(self): for dtype in self.numeric_types: @@ -228,37 +268,38 @@ class BinaryOpsTest(XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - self._testBinary( - math_ops.maximum, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([10, 20], dtype=dtype)) - self._testBinary( - math_ops.maximum, - dtype(5), - np.array([1, 20], dtype=dtype), - expected=np.array([5, 20], dtype=dtype)) - self._testBinary( - math_ops.maximum, - np.array([[10], [2]], dtype=dtype), - dtype(7), - expected=np.array([[10], [7]], dtype=dtype)) + if dtype not in self.complex_types: # min/max not supported for complex + self._testBinary( + math_ops.maximum, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([10, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([5, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[10], [7]], dtype=dtype)) - self._testBinary( - math_ops.minimum, - np.array([1, 20], dtype=dtype), - np.array([10, 2], dtype=dtype), - expected=np.array([1, 2], dtype=dtype)) - self._testBinary( - math_ops.minimum, - dtype(5), - np.array([1, 20], dtype=dtype), - expected=np.array([1, 5], dtype=dtype)) - self._testBinary( - math_ops.minimum, - np.array([[10], [2]], dtype=dtype), - dtype(7), - expected=np.array([[7], [2]], dtype=dtype)) + self._testBinary( + math_ops.minimum, + np.array([1, 20], dtype=dtype), + np.array([10, 2], dtype=dtype), + expected=np.array([1, 2], dtype=dtype)) + self._testBinary( + math_ops.minimum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([1, 5], dtype=dtype)) + self._testBinary( + math_ops.minimum, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[7], [2]], dtype=dtype)) self._testBinary( math_ops.multiply, @@ -276,21 +317,23 @@ class BinaryOpsTest(XLATestCase): dtype(7), expected=np.array([[70], [14]], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([81, 324], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - dtype(5), - np.array([1, 2], dtype=dtype), - expected=np.array([16, 9], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - np.array([[1], [2]], dtype=dtype), - dtype(7), - expected=np.array([[36], [25]], dtype=dtype)) + # Complex support for squared_difference is incidental, see b/68205550 + if dtype not in self.complex_types: + self._testBinary( + math_ops.squared_difference, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([81, 324], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + dtype(5), + np.array([1, 2], dtype=dtype), + expected=np.array([16, 9], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + np.array([[1], [2]], dtype=dtype), + dtype(7), + expected=np.array([[36], [25]], dtype=dtype)) self._testBinary( nn_ops.bias_add, @@ -303,6 +346,139 @@ class BinaryOpsTest(XLATestCase): np.array([2, -1], dtype=dtype), expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + def testComplexOps(self): + for dtype in self.complex_types: + ctypes = {np.complex64: np.float32} + self._testBinary( + math_ops.complex, + np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), + np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]), + expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype)) + + self._testBinary( + lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), + np.array( + [[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]], + dtype=dtype), + np.array( + [[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype), + expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) + + self._testBinary( + gen_math_ops._real_div, + np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype), + np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype), + expected=np.array( + [ + 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2, + float("inf") + ], + dtype=dtype)) + + # TODO(b/65408531): support+test pow for cplx + + lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) + rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) + self._testBinary( + gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) + + self._testBinary( + gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) + + # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow) + + self._testBinary( + gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) + + self._testBinary( + gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) + + def testComplexMath(self): + for dtype in self.complex_types: + self._testBinary( + math_ops.add, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array([11 - 1j, 22 + 24j], dtype=dtype)) + self._testBinary( + math_ops.add, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array([6 - 5j, 7 - 3j], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype)) + + self._testBinary( + math_ops.subtract, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype)) + self._testBinary( + math_ops.subtract, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array([4 - 9j, 3 - 11j], dtype=dtype)) + self._testBinary( + math_ops.subtract, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype)) + + self._testBinary( + math_ops.multiply, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array( + [(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype)) + self._testBinary( + math_ops.multiply, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array( + [(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype)) + self._testBinary( + math_ops.multiply, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array( + [[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype)) + + self._testBinary( + math_ops.div, + np.array([8 - 1j, 2 + 16j], dtype=dtype), + np.array([2 + 4j, 4 - 8j], dtype=dtype), + expected=np.array( + [(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype)) + self._testBinary( + math_ops.div, + dtype(1 + 2j), + np.array([2 + 4j, 4 - 8j], dtype=dtype), + expected=np.array( + [(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype)) + self._testBinary( + math_ops.div, + np.array([2 + 4j, 4 - 8j], dtype=dtype), + dtype(1 + 2j), + expected=np.array( + [(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype)) + + # TODO(b/68205550): math_ops.squared_difference shouldn't be supported. + + self._testBinary( + nn_ops.bias_add, + np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype), + np.array([2 + 6j, -1 - 3j], dtype=dtype), + expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype)) + self._testBinary( + nn_ops.bias_add, + np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype), + np.array([2 + 1j, -1 + 2j], dtype=dtype), + expected=np.array( + [[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype)) + def _testDivision(self, dtype): """Test cases for division operators.""" self._testBinary( @@ -321,18 +497,19 @@ class BinaryOpsTest(XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) - self._testBinary( - gen_math_ops._floor_div, - np.array([3, 3, -1, -9, -8], dtype=dtype), - np.array([2, -2, 7, 2, -4], dtype=dtype), - expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) + if dtype not in self.complex_types: # floordiv unsupported for complex. + self._testBinary( + gen_math_ops._floor_div, + np.array([3, 3, -1, -9, -8], dtype=dtype), + np.array([2, -2, 7, 2, -4], dtype=dtype), + expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): for dtype in self.int_types: self._testDivision(dtype) def testFloatDivision(self): - for dtype in self.float_types: + for dtype in self.float_types + self.complex_types: self._testDivision(dtype) def _testRemainder(self, dtype): @@ -676,6 +853,20 @@ class BinaryOpsTest(XLATestCase): [0, 0, 0, 0, 0, 0]], dtype=dtype)) + self._testBinary( + lambda x, y: array_ops.pad(x, y, constant_values=7), + np.array( + [[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array( + [[0, 3], [2, 1]], dtype=np.int32), + expected=np.array( + [[7, 7, 1, 2, 3, 7], + [7, 7, 4, 5, 6, 7], + [7, 7, 7, 7, 7, 7], + [7, 7, 7, 7, 7, 7], + [7, 7, 7, 7, 7, 7]], + dtype=dtype)) + def testMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index a56c53de0fb5f76c94064e2bdc2f1a543a207b09..0528a5415d579a844e68403ace1bb8982a10a841 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -49,11 +49,15 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, backend_deps = [] backend_data = [] if backend == "cpu": - backend_args += ["--test_device=XLA_CPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] + backend_args += [ + "--test_device=XLA_CPU", + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + ] elif backend == "gpu": - backend_args += ["--test_device=XLA_GPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] + backend_args += [ + "--test_device=XLA_GPU", + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + ] backend_tags += ["requires-gpu-sm35"] elif backend in plugins: backend_args += ["--test_device=" + plugins[backend]["device"], diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index d2a4e4bbd49cd1a78d80163bdbf147c34a455e38..664c77f2000281e3be989665664c1be58d4dd1e5 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -24,9 +24,11 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import flags from tensorflow.python.platform import test -_TEST_TYPES = [dtypes.float32] +FLAGS = flags.FLAGS class GatherTest(xla_test.XLATestCase): @@ -42,7 +44,7 @@ class GatherTest(xla_test.XLATestCase): def testScalar1D(self): with self.test_session() as session, self.test_scope(): data = np.array([0, 1, 2, 3, 7, 5]) - for dtype in _TEST_TYPES: + for dtype in self.all_tf_types: for indices in 4, [1, 2, 2, 4, 5]: params_np = self._buildParams(data, dtype) params = array_ops.placeholder(dtype=dtype) @@ -56,7 +58,7 @@ class GatherTest(xla_test.XLATestCase): with self.test_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) - for dtype in _TEST_TYPES: + for dtype in self.all_tf_types: for axis in 0, 1, -1: params_np = self._buildParams(data, dtype) params = array_ops.placeholder(dtype=dtype) @@ -70,7 +72,7 @@ class GatherTest(xla_test.XLATestCase): with self.test_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) - for dtype in _TEST_TYPES: + for dtype in self.all_tf_types: for axis in 0, 1, -1: params_np = self._buildParams(data, dtype) params = array_ops.placeholder(dtype=dtype) @@ -81,11 +83,34 @@ class GatherTest(xla_test.XLATestCase): expected = np.take(params_np, [0, 1, 0, 2], axis=axis) self.assertAllEqual(expected, gather_val) + def testSimpleTwoD32_Int64Indices(self): + if np.int64 not in self.int_types: + return + + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], + [12, 13, 14]]) + # The indices must be in bounds for any axis. + indices_np = np.array([0, 1, 0, 2]) + for dtype in self.all_tf_types: + for axis in 0, 1, -1: + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices = array_ops.placeholder(dtype=dtypes.int64) + gather_t = array_ops.gather(params, indices, axis=axis) + gather_val = session.run( + gather_t, feed_dict={ + params: params_np, + indices: indices_np + }) + expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + self.assertAllEqual(expected, gather_val) + def testHigherRank(self): - # Check that scalar and empty indices shapes work as well. + """Check that scalar and empty indices shapes work as well.""" shape = (2, 1, 3, 2) for indices_shape in (), (0,), (2, 0), (2, 3): - for dtype in _TEST_TYPES: + for dtype in self.all_tf_types: for axis in 0, 1, 2, 3, -1, -2: params = self._buildParams(np.random.randn(*shape), dtype) indices = np.random.randint(shape[axis], size=indices_shape) @@ -98,5 +123,67 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(gather_np, gather_value) -if __name__ == "__main__": +class GatherBenchmark(test.Benchmark): + """Microbenchmarks for the gather op.""" + + def _benchmarkGather(self, name, axis, gather_indices, use_xla_jit): + + def BuilderFn(): + inputs = variables.Variable( + array_ops.zeros([100, 100, 10, 100, 50], dtype=dtypes.float32), + dtype=dtypes.float32, + name='input') + indices = variables.Variable( + gather_indices, dtype=dtypes.int32, name='indices') + gather_t = array_ops.gather(inputs, indices, axis=axis) + return '%s.axis%d' % (name, axis), [gather_t] + + xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu') + + def _benchmarkSliceGather(self, axis, use_xla_jit): + """Benchmarks a gather op that's really a dynamic slice.""" + self._benchmarkGather('slice_gather', axis, [1], use_xla_jit) + + def _benchmarkNontrivialGather(self, axis, use_xla_jit): + self._benchmarkGather('nontrivial_gather', axis, [9, 1, 0, 2] * 4, + use_xla_jit) + + def benchmarkSliceGatherAxis0(self): + self._benchmarkSliceGather(axis=0, use_xla_jit=False) + + def benchmarkSliceGatherAxis0XLA(self): + self._benchmarkSliceGather(axis=0, use_xla_jit=True) + + def benchmarkSliceGatherAxis1(self): + self._benchmarkSliceGather(axis=1, use_xla_jit=False) + + def benchmarkSliceGatherAxis1XLA(self): + self._benchmarkSliceGather(axis=1, use_xla_jit=True) + + def benchmarkSliceGatherAxis4(self): + self._benchmarkSliceGather(axis=4, use_xla_jit=False) + + def benchmarkSliceGatherAxis4XLA(self): + self._benchmarkSliceGather(axis=4, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis0(self): + self._benchmarkNontrivialGather(axis=0, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis0XLA(self): + self._benchmarkNontrivialGather(axis=0, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis1(self): + self._benchmarkNontrivialGather(axis=1, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis1XLA(self): + self._benchmarkNontrivialGather(axis=1, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis4(self): + self._benchmarkNontrivialGather(axis=4, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis4XLA(self): + self._benchmarkNontrivialGather(axis=4, use_xla_jit=True) + + +if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 11914080eccbf3506e6a17e243bf9f8ba1cbb812..2d8236e2cbdfafb35626cd582ee39b1f917aec7f 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -21,15 +21,12 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.compiler import jit -from tensorflow.core.framework import function_pb2 -from tensorflow.core.framework import node_def_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -118,31 +115,13 @@ class JitLaunchTest(test.TestCase): def testNoOutputs(self): with session_lib.Session() as sess: - # Build a function with a single Const node, whose output is ignored. - fdef = function_pb2.FunctionDef() - fdef.signature.name = "KernelWithNoOutputs" - node = node_def_pb2.NodeDef() - node.op = "Const" - node.name = "ignored" - node.attr["dtype"].type = dtypes.int32.as_datatype_enum - tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[]) - node.attr["value"].tensor.CopyFrom(tensor) - fdef.node_def.extend([node]) # Check that calling the result as a compiled kernel doesn't crash. @function.Defun(compiled=True) def KernelWithNoOutputs(): - return constant_op.constant(100) - - # Hack to override the definition. By accessing .definition, we - # force the _DefinedFunction initialized internally. Then, we - # replace it's internal FunctionDef proto. We do this hack here - # because one typically can't construct KernelWithNoOutputs - # function via Defun decorator directly. - _ = KernelWithNoOutputs.definition - foo = KernelWithNoOutputs - foo._definition = fdef - call = KernelWithNoOutputs() + a = constant_op.constant(100) # pylint: disable=unused-variable + + call = KernelWithNoOutputs() # pylint: disable=assignment-from-no-return sess.run(call, {}) def testAliasing(self): diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index ae60d78f1a8dd898c5428a82be2196b52d4638d8..e4843b169b943b63346b783ddc50039030988ca5 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -68,6 +68,26 @@ class NAryOpsTest(XLATestCase): np.array([42], dtype=np.float32)], expected=np.array([48], dtype=np.float32)) + def testComplex(self): + for dtype in self.complex_types: + self._testNAry( + math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)], + expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)) + + self._testNAry( + math_ops.add_n, [ + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.array([10j, 20], dtype=dtype) + ], + expected=np.array([1 + 12j, 22 - 3j], dtype=dtype)) + self._testNAry( + math_ops.add_n, [ + np.array([-4, 5j], dtype=dtype), + np.array([2 + 10j, -2], dtype=dtype), + np.array([42j, 3 + 3j], dtype=dtype) + ], + expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype)) + @unittest.skip("IdentityN is temporarily CompilationOnly as workaround") def testIdentityN(self): self._testNAryLists(array_ops.identity_n, diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index a17a3f3d6536eea780106d84bcf4ce92c0fd017e..d6c93088d4efff7d8306e262a79ae49d3d8ac722 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -29,6 +29,9 @@ from tensorflow.python.platform import googletest class RandomOpsTest(XLATestCase): """Test cases for random-number generating operators.""" + def _random_types(self): + return set(self.numeric_types) - set(self.complex_types) + def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. with self.test_session() as sess: @@ -51,7 +54,8 @@ class RandomOpsTest(XLATestCase): def rng(dtype): return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) - for dtype in self.numeric_types: + + for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): @@ -63,7 +67,7 @@ class RandomOpsTest(XLATestCase): self._testRngIsNotConstant(rng, dtype) def testRandomUniformIsInRange(self): - for dtype in self.numeric_types: + for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 7e307f16afcedeb39b0a394bad8412aac565321a..c8a32f9e29ee5582ea69a9adce813c4250325226 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -75,7 +75,7 @@ namespace { // Command line flags: see main() below. int64 tf_xla_random_seed = 0; int32 tf_xla_test_repetitions = 20; -int64 tf_xla_max_tensor_size = 100000LL; +int64 tf_xla_max_tensor_size = 10000LL; string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; @@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) { return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); } -constexpr std::array kAllXlaTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL}}; +constexpr std::array kAllXlaTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}}; // An OpTestBuilder is a graph builder class that takes as input an operator to // test, its inputs and attributes, and builds a graph that executes the @@ -367,11 +367,11 @@ OpTest::OpTest() { void OpTest::Repeatedly(const std::function& fn) { int const max_repetitions = tf_xla_test_repetitions; int valid_test_runs = 0; - // We run up to 20 * max_repetitions times; the idea is that if we roll the + // We run up to 100 * max_repetitions times; the idea is that if we roll the // dice enough times we will find some valid parameters. We want to put an // upper limit on the number iterations just in case the probability of // finding feasible parameters is very low. - for (int i = 0; !HasFailure() && i < max_repetitions * 20 && + for (int i = 0; !HasFailure() && i < max_repetitions * 100 && valid_test_runs < max_repetitions; ++i) { TestResult result = fn(); @@ -449,6 +449,13 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { }); break; } + case DT_COMPLEX64: { + std::uniform_real_distribution distribution(-1.0f, 1.0f); + test::FillFn(&tensor, [this, &distribution](int i) { + return complex64(distribution(generator()), distribution(generator())); + }); + break; + } case DT_INT32: { std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); test::FillFn(&tensor, [this, &distribution](int i) -> int32 { @@ -624,11 +631,47 @@ std::vector OpTest::AsInt32s(const std::vector& int64s) { // Functions for comparing tensors. +template +double Abs(T x) { + return std::fabs(x); +} + +template <> +double Abs(complex64 x) { + return std::abs(x); +} + template bool IsClose(const T& x, const T& y, double atol, double rtol) { if (std::isnan(x) && std::isnan(y)) return true; if (x == y) return true; // Allow inf == inf. - return fabs(x - y) < atol + rtol * fabs(x); + return Abs(x - y) < atol + rtol * Abs(x); +} + +template <> +bool IsClose(const complex64& x, const complex64& y, double atol, + double rtol) { + if (std::isnan(x.real()) && std::isnan(y.real())) { + if (std::isnan(x.imag()) && std::isnan(y.imag())) { + return true; + } + if (x.imag() == y.imag()) return true; // Allow inf == inf. + return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag()); + } else if (std::isnan(x.imag()) && std::isnan(y.imag())) { + if (x.real() == y.real()) return true; // Allow inf == inf. + return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real()); + } + if (x == y) return true; // Allow inf == inf. + return Abs(x - y) < atol + rtol * Abs(x); +} + +template +string Str(T x) { + return strings::StrCat(x); +} +template <> +string Str(complex64 x) { + return strings::StrCat("(", x.real(), ", ", x.imag(), ")"); } template @@ -639,9 +682,10 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, for (int i = 0; i < Tx.size(); ++i) { if (!IsClose(Tx(i), Ty(i), atol, rtol)) { return errors::InvalidArgument(strings::StrCat( - i, "-th tensor element isn't close: ", Tx(i), " vs. ", Ty(i), - ". x = ", x.DebugString(), "y = ", y.DebugString(), "atol = ", atol, - " rtol = ", rtol, " tol = ", atol + rtol * std::fabs(Tx(i)))); + i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ", + Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(), + "atol = ", atol, " rtol = ", rtol, + " tol = ", atol + rtol * Abs(Tx(i)))); } } return Status::OK(); @@ -683,6 +727,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, return TensorsAreCloseImpl(a, b, atol, rtol); case DT_DOUBLE: return TensorsAreCloseImpl(a, b, atol, rtol); + case DT_COMPLEX64: + return TensorsAreCloseImpl(a, b, atol, rtol); case DT_INT32: return TensorsAreEqualImpl(a, b); case DT_INT64: @@ -822,7 +868,7 @@ Tensor AsIntTensor(DataType dtype, const std::vector& values) { TEST_F(OpTest, Abs) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Abs").RandomInput(type).Attr("T", type)); }); @@ -837,7 +883,7 @@ TEST_F(OpTest, Acosh) { TEST_F(OpTest, Add) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add") .RandomInput(type, dims.first) @@ -848,7 +894,7 @@ TEST_F(OpTest, Add) { TEST_F(OpTest, AddN) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); int n = std::uniform_int_distribution(1, 5)(generator()); auto shape = RandomDims(); @@ -875,6 +921,14 @@ TEST_F(OpTest, All) { }); } +TEST_F(OpTest, Angle) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Angle") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + TEST_F(OpTest, Any) { Repeatedly([this]() { std::vector data_dims = RandomDims(); @@ -889,17 +943,18 @@ TEST_F(OpTest, Any) { TEST_F(OpTest, ApproximateEqual) { Repeatedly([this]() { - auto dims = RandomDims(); + auto dims = BroadcastableDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, ArgMax) { Repeatedly([this]() { - std::vector dims = RandomDims(1, 5); + std::vector dims = RandomDims(1, 5, 1); int num_dims = dims.size(); int reduce_dim = std::uniform_int_distribution(-num_dims, num_dims)(generator()); @@ -943,6 +998,16 @@ TEST_F(OpTest, Atanh) { }); } +TEST_F(OpTest, Atan2) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Atan2") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, AvgPool) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); @@ -1038,6 +1103,7 @@ TEST_F(OpTest, AvgPool3DGrad) { TEST_F(OpTest, BatchMatMul) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); std::vector output_dims = RandomDims(2, 5, 0, 7); int64 ndims = output_dims.size(); int64 inner_dim = RandomDim(); @@ -1056,9 +1122,9 @@ TEST_F(OpTest, BatchMatMul) { } return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .RandomInput(DT_FLOAT, x_dims) - .RandomInput(DT_FLOAT, y_dims) - .Attr("T", DT_FLOAT) + .RandomInput(type, x_dims) + .RandomInput(type, y_dims) + .Attr("T", type) .Attr("adj_x", adj_x) .Attr("adj_y", adj_y)); }); @@ -1090,10 +1156,11 @@ TEST_F(OpTest, BatchToSpace) { CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(crops) - .Attr("T", DT_FLOAT) + .Attr("T", type) .Attr("block_size", block_size)); }); } @@ -1127,13 +1194,14 @@ TEST_F(OpTest, BatchToSpaceND) { CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("BatchToSpaceND") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(test::AsTensor( std::vector(block_dims.begin(), block_dims.end()))) .Input(crops) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); }); } @@ -1142,18 +1210,20 @@ TEST_F(OpTest, BiasAdd) { auto x_dims = RandomDims(2, kDefaultMaxRank); auto y_dims = {x_dims[x_dims.size() - 1]}; // TODO(phawkins): test both data formats. + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd") - .RandomInput(DT_FLOAT, x_dims) - .RandomInput(DT_FLOAT, y_dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, x_dims) + .RandomInput(type, y_dims) + .Attr("T", type)); }); } TEST_F(OpTest, BiasAddGrad) { Repeatedly([this]() { // TODO(phawkins): test both data formats. + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("BiasAddGrad").RandomInput(type).Attr("T", type)); }); } @@ -1161,17 +1231,40 @@ TEST_F(OpTest, BiasAddV1) { Repeatedly([this]() { auto x_dims = RandomDims(2, kDefaultMaxRank); auto y_dims = {x_dims[x_dims.size() - 1]}; + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1") - .RandomInput(DT_FLOAT, x_dims) - .RandomInput(DT_FLOAT, y_dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, x_dims) + .RandomInput(type, y_dims) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, BitwiseAnd) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseAnd") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, BitwiseOr) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseOr") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, BroadcastArgs) { Repeatedly([this]() { // TODO(phawkins): only int32 seems to be implemented in Tensorflow. - // DataType type = Choose({DT_INT32, DT_INT64}); + // auto type = Choose({DT_INT32, DT_INT64}); DataType type = DT_INT32; auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose( @@ -1185,7 +1278,7 @@ TEST_F(OpTest, BroadcastArgs) { TEST_F(OpTest, BroadcastGradientArgs) { Repeatedly([this]() { // TODO(phawkins): only int32 seems to be implemented in Tensorflow. - // DataType type = Choose({DT_INT32, DT_INT64}); + // auto type = Choose({DT_INT32, DT_INT64}); DataType type = DT_INT32; auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose( @@ -1199,8 +1292,8 @@ TEST_F(OpTest, BroadcastGradientArgs) { TEST_F(OpTest, Cast) { Repeatedly([this]() { DataType src_type, dst_type; - src_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); - dst_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); + src_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}); + dst_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") .RandomInput(src_type) .Attr("SrcT", src_type) @@ -1215,9 +1308,19 @@ TEST_F(OpTest, Ceil) { }); } +TEST_F(OpTest, Complex) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Complex") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Concat) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(2, 5)(generator()); std::vector dims = RandomDims(1); @@ -1257,6 +1360,14 @@ TEST_F(OpTest, ConcatOffset) { }); } +TEST_F(OpTest, Conj) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Conj") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + TEST_F(OpTest, Conv2D) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); @@ -1271,11 +1382,12 @@ TEST_F(OpTest, Conv2D) { std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2D") - .RandomInput(DT_FLOAT, data_dims) - .RandomInput(DT_FLOAT, kernel_dims) - .Attr("T", DT_FLOAT) + .RandomInput(type, data_dims) + .RandomInput(type, kernel_dims) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); @@ -1295,12 +1407,13 @@ TEST_F(OpTest, Conv2DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropFilter") - .RandomInput(DT_FLOAT, activations) + .RandomInput(type, activations) .Input(kernel_shape) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); @@ -1320,12 +1433,13 @@ TEST_F(OpTest, Conv2DBackpropInput) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropInput") .Input(in_shape) - .RandomInput(DT_FLOAT, kernel) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, kernel) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); @@ -1343,11 +1457,12 @@ TEST_F(OpTest, Conv3D) { std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out}; + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3D") - .RandomInput(DT_FLOAT, data) - .RandomInput(DT_FLOAT, kernel) - .Attr("T", DT_FLOAT) + .RandomInput(type, data) + .RandomInput(type, kernel) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); @@ -1367,12 +1482,13 @@ TEST_F(OpTest, Conv3DBackpropFilter) { Tensor kernel_shape = test::AsTensor( AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out})); + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropFilterV2") - .RandomInput(DT_FLOAT, activations) + .RandomInput(type, activations) .Input(kernel_shape) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); @@ -1391,17 +1507,34 @@ TEST_F(OpTest, Conv3DBackpropInput) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out}; + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropInputV2") .Input(in_shape) - .RandomInput(DT_FLOAT, kernel) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, kernel) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); } +TEST_F(OpTest, Cos) { + Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Cos").RandomInput(type).Attr("T", type)); + }); +} + +TEST_F(OpTest, Cosh) { + Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Cosh").RandomInput(type).Attr("T", type)); + }); +} + TEST_F(OpTest, DepthToSpace) { Repeatedly([this]() { int64 block = RandomDim(2, 5); @@ -1409,14 +1542,16 @@ TEST_F(OpTest, DepthToSpace) { input_dims[1] = (input_dims[1] + (block - 1)) / block; input_dims[2] = (input_dims[2] + (block - 1)) / block; input_dims[3] *= block * block; + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DepthToSpace") - .RandomInput(DT_FLOAT, input_dims) - .Attr("T", DT_FLOAT) + .RandomInput(type, input_dims) + .Attr("T", type) .Attr("block_size", block)); }); } TEST_F(OpTest, DepthwiseConv2DNative) { + if (1) return; Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); std::uniform_int_distribution random_int(1, 5); @@ -1427,17 +1562,20 @@ TEST_F(OpTest, DepthwiseConv2DNative) { std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}; + std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); + strides[2] = strides[1]; // Current impl only supports equal strides return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("DepthwiseConv2dNative") .RandomInput(DT_FLOAT, input_dims) .RandomInput(DT_FLOAT, kernel_dims) .Attr("T", DT_FLOAT) - .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("strides", strides) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); } TEST_F(OpTest, DepthwiseConv2DBackpropFilter) { + if (1) return; Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); std::uniform_int_distribution random_int(1, 5); @@ -1450,33 +1588,22 @@ TEST_F(OpTest, DepthwiseConv2DBackpropFilter) { FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier})); + std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); + strides[2] = strides[1]; // Current impl only supports equal strides return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("DepthwiseConv2dNativeBackpropFilter") .RandomInput(DT_FLOAT, activations) .Input(kernel_shape) .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) - .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("strides", strides) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); }); } -TEST_F(OpTest, Cos) { - Repeatedly([this]() { - return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Cos").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); - }); -} - -TEST_F(OpTest, Cosh) { - Repeatedly([this]() { - return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Cosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); - }); -} - TEST_F(OpTest, DepthwiseConv2DBackpropInput) { + if (1) return; Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); std::uniform_int_distribution random_int(1, 5); @@ -1489,21 +1616,24 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) { FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}; + std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); + strides[2] = strides[1]; // Current impl only supports equal strides return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("DepthwiseConv2dNativeBackpropInput") .Input(in_shape) .RandomInput(DT_FLOAT, kernel) .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) - .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("strides", strides) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); }); } TEST_F(OpTest, Diag) { + if (1) return; Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); std::vector dims; // Diag causes a quadratic blowup in output size. int64 size; @@ -1518,7 +1648,7 @@ TEST_F(OpTest, Diag) { TEST_F(OpTest, DiagPart) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); auto dims = RandomDims(1, 3); // Duplicate the random dims. std::vector doubled_dims(dims.size() * 2); @@ -1532,7 +1662,7 @@ TEST_F(OpTest, DiagPart) { TEST_F(OpTest, Div) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div") .RandomInput(type, dims.first) @@ -1543,7 +1673,7 @@ TEST_F(OpTest, Div) { TEST_F(OpTest, DynamicStitch) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(2, 5)(generator()); OpTestBuilder builder("DynamicStitch"); builder.Attr("T", type); @@ -1628,7 +1758,7 @@ TEST_F(OpTest, SeluGrad) { TEST_F(OpTest, Equal) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal") .RandomInput(type, dims.first) @@ -1639,21 +1769,23 @@ TEST_F(OpTest, Equal) { TEST_F(OpTest, Exp) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Exp").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Expm1) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Expm1").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Expm1").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, ExpandDims) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector in_dims = RandomDims(); Tensor dim(DT_INT32, TensorShape()); std::uniform_int_distribution d(-1 - in_dims.size(), in_dims.size()); @@ -1667,7 +1799,7 @@ TEST_F(OpTest, ExpandDims) { TEST_F(OpTest, Fill) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); std::vector shape(dims.begin(), dims.end()); return ExpectTfAndXlaOutputsAreClose( @@ -1698,7 +1830,7 @@ TEST_F(OpTest, FloorDiv) { TEST_F(OpTest, FloorMod) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod") .RandomInput(type, dims.first) @@ -1709,7 +1841,7 @@ TEST_F(OpTest, FloorMod) { TEST_F(OpTest, Greater) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater") .RandomInput(type, dims.first) @@ -1720,7 +1852,7 @@ TEST_F(OpTest, Greater) { TEST_F(OpTest, GreaterEqual) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual") .RandomInput(type, dims.first) @@ -1729,6 +1861,22 @@ TEST_F(OpTest, GreaterEqual) { }); } +TEST_F(OpTest, Imag) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Imag") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + +TEST_F(OpTest, Invert) { + Repeatedly([this]() { + DataType type = DT_INT32; + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Invert").RandomInput(type).Attr("T", type)); + }); +} + TEST_F(OpTest, L2Loss) { Repeatedly([this]() { DataType type = DT_FLOAT; @@ -1739,7 +1887,7 @@ TEST_F(OpTest, L2Loss) { TEST_F(OpTest, Less) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less") .RandomInput(type, dims.first) @@ -1750,7 +1898,7 @@ TEST_F(OpTest, Less) { TEST_F(OpTest, LessEqual) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual") .RandomInput(type, dims.first) @@ -1766,7 +1914,7 @@ TEST_F(OpTest, LinSpace) { return test::AsScalar(x); }; std::uniform_int_distribution distribution(-50, 50); - DataType type = Choose({DT_INT32, DT_INT64}); + auto type = Choose({DT_INT32, DT_INT64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LinSpace") .RandomInput(DT_FLOAT, {}) @@ -1779,15 +1927,17 @@ TEST_F(OpTest, LinSpace) { TEST_F(OpTest, Log) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Log").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Log1p) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Log1p").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Log1p").RandomInput(type).Attr("T", DT_FLOAT)); }); } @@ -1884,10 +2034,11 @@ TEST_F(OpTest, MatMul) { std::swap(b_dims[0], b_dims[1]); } + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .RandomInput(DT_FLOAT, a_dims) - .RandomInput(DT_FLOAT, b_dims) - .Attr("T", DT_FLOAT) + .RandomInput(type, a_dims) + .RandomInput(type, b_dims) + .Attr("T", type) .Attr("transpose_a", transpose_a) .Attr("transpose_b", transpose_b)); }); @@ -1895,7 +2046,7 @@ TEST_F(OpTest, MatMul) { TEST_F(OpTest, MatrixDiag) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag") .RandomInput(type, RandomDims(1)) .Attr("T", type)); @@ -1904,7 +2055,7 @@ TEST_F(OpTest, MatrixDiag) { TEST_F(OpTest, MatrixDiagPart) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart") .RandomInput(type, RandomDims(2)) .Attr("T", type)); @@ -1913,7 +2064,7 @@ TEST_F(OpTest, MatrixDiagPart) { TEST_F(OpTest, Max) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -1927,7 +2078,7 @@ TEST_F(OpTest, Max) { TEST_F(OpTest, Maximum) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum") .RandomInput(type, dims.first) @@ -1995,7 +2146,7 @@ TEST_F(OpTest, MaxPool3D) { TEST_F(OpTest, Mean) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); // TODO(phawkins): CPU and XLA differ output for reducing across a // size-0 dimension (nan vs 0). For now, require size >= 1. std::vector data_dims = RandomDims(0, kDefaultMaxRank, 1); @@ -2011,7 +2162,7 @@ TEST_F(OpTest, Mean) { TEST_F(OpTest, Min) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -2025,7 +2176,7 @@ TEST_F(OpTest, Min) { TEST_F(OpTest, Minimum) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum") .RandomInput(type, dims.first) @@ -2046,7 +2197,7 @@ TEST_F(OpTest, Mod) { TEST_F(OpTest, Mul) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul") .RandomInput(type, dims.first) @@ -2057,7 +2208,7 @@ TEST_F(OpTest, Mul) { TEST_F(OpTest, Neg) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Neg").RandomInput(type).Attr("T", type)); }); @@ -2065,7 +2216,7 @@ TEST_F(OpTest, Neg) { TEST_F(OpTest, NotEqual) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual") .RandomInput(type, dims.first) @@ -2076,7 +2227,7 @@ TEST_F(OpTest, NotEqual) { TEST_F(OpTest, OneHot) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); int num_dims = dims.size(); @@ -2106,7 +2257,7 @@ TEST_F(OpTest, OneHot) { TEST_F(OpTest, OnesLike) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type)); }); @@ -2114,7 +2265,7 @@ TEST_F(OpTest, OnesLike) { TEST_F(OpTest, Pack) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(1, 5)(generator()); std::vector dims = RandomDims(); @@ -2136,7 +2287,7 @@ TEST_F(OpTest, Pack) { // TODO(b/31741898): crashes on GPU. TEST_F(OpTest, Pad) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(); // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. @@ -2165,16 +2316,17 @@ TEST_F(OpTest, Pow) { // nontermination. Repeatedly([this]() { auto dims = BroadcastableDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow") - .RandomInput(DT_FLOAT, dims.first) - .RandomInput(DT_FLOAT, dims.second) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Prod) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -2208,15 +2360,23 @@ TEST_F(OpTest, Range) { TEST_F(OpTest, Rank) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Rank").RandomInput(type).Attr("T", type)); }); } +TEST_F(OpTest, Real) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Real") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + TEST_F(OpTest, RealDiv) { Repeatedly([this]() { - DataType type = DT_FLOAT; + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv") .RandomInput(type, dims.first) @@ -2227,18 +2387,20 @@ TEST_F(OpTest, RealDiv) { TEST_F(OpTest, Reciprocal) { Repeatedly([this]() { + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Reciprocal").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, ReciprocalGrad) { Repeatedly([this]() { std::vector dims = RandomDims(); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Relu) { @@ -2277,7 +2439,7 @@ TEST_F(OpTest, ReluGrad) { TEST_F(OpTest, Reshape) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); std::bernoulli_distribution random_bool; std::vector dims_before, dims_after; @@ -2305,24 +2467,24 @@ TEST_F(OpTest, Reshape) { TEST_F(OpTest, Reverse) { Repeatedly([this]() { std::vector dims = RandomDims(1); - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); int64 rank = dims.size(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse") .RandomInput(type, dims) .RandomInput(DT_BOOL, {rank}) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); }); } TEST_F(OpTest, ReverseV2) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2") .RandomInput(type, data_dims) .Input(indices) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); }); } @@ -2342,24 +2504,26 @@ TEST_F(OpTest, Round) { TEST_F(OpTest, Rsqrt) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Rsqrt").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, RsqrtGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Shape) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Shape").RandomInput(type).Attr("T", type)); }); @@ -2367,7 +2531,7 @@ TEST_F(OpTest, Shape) { TEST_F(OpTest, ShapeN) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(1, 5)(generator()); OpTestBuilder builder("ShapeN"); builder.Attr("T", type); @@ -2381,24 +2545,26 @@ TEST_F(OpTest, ShapeN) { TEST_F(OpTest, Sigmoid) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sigmoid").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, SigmoidGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Sign) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Sign").RandomInput(type).Attr("T", type)); }); @@ -2406,21 +2572,23 @@ TEST_F(OpTest, Sign) { TEST_F(OpTest, Sin) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sin").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sin").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Sinh) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sinh").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Size) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Size").RandomInput(type).Attr("T", type)); }); @@ -2428,7 +2596,7 @@ TEST_F(OpTest, Size) { TEST_F(OpTest, Slice) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); std::vector begin(data_dims.size()), size(data_dims.size()); @@ -2532,10 +2700,11 @@ TEST_F(OpTest, SpaceToBatch) { CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(paddings) - .Attr("T", DT_FLOAT) + .Attr("T", type) .Attr("block_size", block_size)); }); } @@ -2573,13 +2742,14 @@ TEST_F(OpTest, SpaceToBatchND) { CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SpaceToBatchND") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(test::AsTensor( std::vector(block_dims.begin(), block_dims.end()))) .Input(paddings) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); }); } @@ -2649,7 +2819,7 @@ TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) { TEST_F(OpTest, Split) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(1); std::uniform_int_distribution ud; int32 dim = std::uniform_int_distribution( @@ -2669,18 +2839,20 @@ TEST_F(OpTest, Split) { TEST_F(OpTest, Sqrt) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sqrt").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, SqrtGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } @@ -2696,7 +2868,7 @@ TEST_F(OpTest, SquaredDifference) { TEST_F(OpTest, Square) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Square").RandomInput(type).Attr("T", type)); }); @@ -2704,7 +2876,7 @@ TEST_F(OpTest, Square) { TEST_F(OpTest, Squeeze) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(0, kDefaultMaxRank, 0, 5); std::bernoulli_distribution random_bool; std::vector squeeze_dims; @@ -2722,7 +2894,7 @@ TEST_F(OpTest, Squeeze) { TEST_F(OpTest, Sub) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub") .RandomInput(type, dims.first) @@ -2733,7 +2905,7 @@ TEST_F(OpTest, Sub) { TEST_F(OpTest, Sum) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -2747,7 +2919,7 @@ TEST_F(OpTest, Sum) { TEST_F(OpTest, StridedSlice) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); std::vector begin(data_dims.size()), end(data_dims.size()); std::vector strides(data_dims.size()); @@ -2792,7 +2964,7 @@ TEST_F(OpTest, StridedSlice) { TEST_F(OpTest, StridedSliceGrad) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); // Dimensions of the forward input. std::vector dims = RandomDims(); @@ -2845,31 +3017,34 @@ TEST_F(OpTest, StridedSliceGrad) { TEST_F(OpTest, Tan) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Tan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Tan").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Tanh) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Tanh").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, TanhGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Tile) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(1); std::vector multiples(t_dims.size()); for (int i = 0; i < t_dims.size(); ++i) { @@ -2885,7 +3060,7 @@ TEST_F(OpTest, Tile) { TEST_F(OpTest, Transpose) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); std::vector perm(data_dims.size()); std::iota(perm.begin(), perm.end(), 0); @@ -2910,7 +3085,7 @@ TEST_F(OpTest, TruncateDiv) { TEST_F(OpTest, TruncateMod) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod") .RandomInput(type, dims.first) @@ -2921,7 +3096,7 @@ TEST_F(OpTest, TruncateMod) { TEST_F(OpTest, ZerosLike) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type)); }); diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 6f19834160d5b430acab52f06ef7837ba276d4a2..76644380bdf2e0c24f6d363ddfaabdff836495d7 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -26,6 +26,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -327,6 +328,138 @@ class UnaryOpsTest(XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype)) + def testComplexOps(self): + for dtype in self.complex_types: + # TODO(b/65408531): math_ops.acosh (needs pow) + # TODO(b/65408531): math_ops.asinh (needs pow) + + # TODO(b/65408531): Wider support for log (needs atan2). + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._assertOpOutputMatchesExpected( + math_ops.atanh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arctanh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.cosh, + np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype), + expected=np.cosh(np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.sinh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.sinh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.exp, + np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), + expected=np.exp(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.expm1, + np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), + expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.reciprocal, + np.array([[1, 2j, 2 + 3j]], dtype=dtype), + expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype)) + + if atan2_supported: + self._assertOpOutputMatchesExpected( + math_ops.log, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.sin, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.sin(np.array([[5j, 3 - 2j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.cos, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) + + # TODO(b/34703906): improve log1p implementation and make tolerance + # tighter. + if atan2_supported: # TODO(b/34703906): log support + self._assertOpOutputMatchesExpected( + math_ops.log1p, + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), + expected=np.log1p( + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + + # TODO(b/34703906): math_ops.rsqrt (needs pow) + + # TODO(b/34703906): math_ops.sigmoid (needs tanh) + + # TODO(b/34703906): math_ops.sqrt (needs pow) + + self._assertOpOutputMatchesExpected( + math_ops.tan, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + + # TODO(b/34703906): math_ops.tanh (as itself) + + ctypes = {np.complex64: np.float32} + self._assertOpOutputMatchesExpected( + math_ops.abs, + np.array([[3 - 4j, -1j, np.inf]], dtype=dtype), + expected=np.array([[5, 1, np.inf]], dtype=ctypes[dtype])) + + self._assertOpOutputMatchesExpected( + math_ops.negative, + np.array([[-1 + 2j, -3j]], dtype=dtype), + expected=np.array([[1 - 2j, 3j]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.square, + np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype), + expected=np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype)**2) + + self._assertOpOutputMatchesExpected( + array_ops.zeros_like, + np.array([[4j, 3 - 2j], [2, -1j]], dtype=dtype), + expected=np.array([[0, 0], [0, 0]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + array_ops.ones_like, + np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype), + expected=np.array([[1, 1], [1, 1]], dtype=dtype)) + + if atan2_supported: # TODO(b/34703906): atan2 support + self._assertOpOutputMatchesExpected( + math_ops.angle, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.angle( + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.conj, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.array([1 - 3j, -4 - 7j, 2.7, 3j], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.imag, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.array([3, 7, 0, -3], dtype=ctypes[dtype])) + + self._assertOpOutputMatchesExpected( + math_ops.real, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype])) + + def testIntOps(self): + for dtype in self.int_types: + self._assertOpOutputMatchesExpected( + bitwise_ops.invert, + np.array([0, -1, 1, 16, 42], dtype=dtype), + expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) + def testNumericOps(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -391,11 +524,14 @@ class UnaryOpsTest(XLATestCase): def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = [dtypes.bool, dtypes.int32, dtypes.float32] + types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types for shape in shapes: for src_type in types: for dst_type in types: src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype) + if src_type in self.complex_tf_types: + src += (np.arange(np.prod(shape)) * 2j).astype( + src_type.as_numpy_dtype) src = src.reshape(shape) dst = src.astype(dst_type.as_numpy_dtype) @@ -558,5 +694,6 @@ class UnaryOpsTest(XLATestCase): log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten, -log_eps + ten], dtype) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index fdf3f9fb6ada762751f8639af29bec0b0d9a8b01..c50342dee45eba6ae54f01653ecc81ef096b547b 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -43,7 +43,7 @@ class VariableOpsTest(XLATestCase): # Regression test for a bug where computations with one non-constant # output and one variable update were mishandled. for dtype in self.numeric_types: - init = np.array([[1, 2], [3, 4]], dtype=dtype) + init = np.array([[1, 2j], [3, 4]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) @@ -51,82 +51,91 @@ class VariableOpsTest(XLATestCase): x = v.assign_add(p) with ops.control_dependencies([x]): y = v.read_value() - self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype), - sess.run(y, {p: 1})) + self.assertAllClose( + np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, { + p: 1 + })) def testSparseRead0DIndices(self): for dtype in self.numeric_types: - init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10, + 11]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read(2) - self.assertAllClose(np.array([8, 9, 10, 11], dtype=dtype), sess.run(x)) + self.assertAllClose( + np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x)) def testSparseRead1DIndices(self): for dtype in self.numeric_types: - init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10, + 11]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([2, 1]) self.assertAllClose( - np.array([[8, 9, 10, 11], [4, 5, 6, 7]], dtype=dtype), sess.run(x)) + np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype), + sess.run(x)) def testSparseRead2DIndices(self): for dtype in self.numeric_types: - init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10, + 11]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([[2, 1], [0, 2]]) self.assertAllClose( - np.array( - [[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2, 3], [8, 9, 10, - 11]]], - dtype=dtype), sess.run(x)) + np.array([[[8, 9, 10, 11], [4, 5, 6, 7]], + [[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype), + sess.run(x)) def testSparseRead2DIndices3DTensor(self): for dtype in self.numeric_types: - init = np.array( - [[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]], - [[20, 21, 22], [23, 24, 25]], [[30, 31, 32], [33, 34, 35]]], - dtype=dtype) + init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]], + [[20, 21, 22], [23, 24j, 25]], + [[30, 31, 32], [33, 34, 35]]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24, 25]], [[10, 11, 12], [13, 14, 15]]], + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], - dtype=dtype), sess.run(x)) + ).astype(dtype), sess.run(x)) def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" - with self.test_session() as session: - with self.test_scope(): - with variable_scope.variable_scope("ascope", use_resource=True): - x = variable_scope.get_variable( - "x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(2)) - a = x.read_value() - with ops.control_dependencies([a]): - b = state_ops.assign(x, 47) - with ops.control_dependencies([b]): - c = x.read_value() - with ops.control_dependencies([c]): - d = state_ops.assign_add(x, 3) - with ops.control_dependencies([d]): - e = x.read_value() - - session.run(variables.global_variables_initializer()) - v1, v2, v3 = session.run([a, c, e]) - self.assertAllClose(2.0, v1) - self.assertAllClose(47.0, v2) - self.assertAllClose(50.0, v3) + for dtype in self.numeric_types: + with self.test_session() as session: + print(ops.get_default_graph()) + with self.test_scope(): + with variable_scope.variable_scope("ascope", use_resource=True): + x = variable_scope.get_variable( + "x", + shape=[], + dtype=dtype, + initializer=init_ops.constant_initializer(2)) + a = x.read_value() + with ops.control_dependencies([a]): + b = state_ops.assign(x, dtype(47)) + with ops.control_dependencies([b]): + c = x.read_value() + with ops.control_dependencies([c]): + d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype)) + with ops.control_dependencies([d]): + e = state_ops.assign_sub(x, dtype(3)) + with ops.control_dependencies([e]): + f = x.read_value() + + session.run(variables.global_variables_initializer()) + v1, v2, v3 = session.run([a, c, f]) + self.assertAllClose(dtype(2), v1) + self.assertAllClose(dtype(47), v2) + self.assertAllClose(np.array(50 + 2j).astype(dtype), v3) def testTraining(self): """Tests a gradient descent step for a simple model.""" diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index da6dc88f1fb07200799f8ee231fc04628b265e24..0be127997e5211f810ca791187486760881fe172 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -63,12 +63,19 @@ class XLATestCase(test.TestCase): self.float_tf_types = [ dtype for dtype in self.all_tf_types if dtype.is_floating ] - self.numeric_tf_types = self.int_tf_types + self.float_tf_types + self.complex_tf_types = [ + dtype for dtype in self.all_tf_types if dtype.is_complex + ] + self.numeric_tf_types = ( + self.int_tf_types + self.float_tf_types + self.complex_tf_types) self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types] self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types] - self.numeric_types = self.int_types + self.float_types + self.complex_types = [ + dtype.as_numpy_dtype for dtype in self.complex_tf_types + ] + self.numeric_types = self.int_types + self.float_types + self.complex_types # Parse the manifest file, if any, into a regex identifying tests to # disable diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 08f2249e0d767e7493dce7be55793083cd1cc6c8..3c94bcafc1d19b1bc54887e6f2c25b1886be646e 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -87,6 +87,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -102,11 +103,13 @@ cc_library( "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", + "graph_compiler.cc", "xla_cpu_backend.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]), hdrs = [ + "graph_compiler.h", "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", @@ -117,6 +120,7 @@ cc_library( visibility = [":friends"], deps = [ ":common", + ":const_analysis", ":dump_graph", ":functionalize_control_flow", "//tensorflow/compiler/xla:literal_util", @@ -224,7 +228,6 @@ tf_cc_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -253,6 +256,7 @@ tf_cc_test( "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -347,6 +351,7 @@ cc_library( hdrs = ["functionalize_control_flow.h"], deps = [ "//tensorflow/compiler/jit:graph_to_functiondef", + "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:functional_ops", "//tensorflow/compiler/xla:status_macros", @@ -354,6 +359,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:lib", ], ) @@ -371,6 +377,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:functional_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:ops", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index bf75f85db041087d8770bd21494f8e1a7fe8c1b5..102a2cf07b51486bb445b0311966717b7e82ace6 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -67,6 +67,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Min", "reduction_indices"}, {"OneHot", "depth"}, {"Pad", "paddings"}, + {"PadV2", "paddings"}, {"MirrorPad", "paddings"}, {"Prod", "reduction_indices"}, {"RandomStandardNormal", "shape"}, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1c7a2046aa549beb2de58d21f517363d4fe8aea7..35b6960a98cda1bf098f3e01cac3df8173bdc729 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -17,15 +17,19 @@ limitations under the License. #include #include +#include #include #include #include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -70,11 +74,24 @@ struct Frame { std::unordered_set nodes; }; +// Returns a textual representation of the names of the nodes in the input. +template +string NodesToString(const T& nodes) { + return strings::StrCat("{", + str_util::Join(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), + "}"); +} + // Copies a subgraph from `graph` to `output` by performing a reverse DFS // starting at nodes in vector `stack`. // `node_map` is a vector indexed by source node ID to dest nodes. // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` -// before the traversal clients can cut the graph. Returns an error if the +// before the traversal clients can cut the graph. If a frame is provided (frame +// != nullptr), then this functions will return an error if the // traversal leaves 'frame'; the client must add enough nodes to `node_map` to // cut the graph and prevent the traversal from escaping. // @@ -84,25 +101,26 @@ struct Frame { // taking from the Switch node was not necessarily the first output, but _Arg // nodes only have one output. By adding the Switch node to `squash_src_outputs` // we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const Frame& frame, +Status CopySubgraph(const Graph& graph, const Frame* frame, std::vector stack, const std::vector& squash_src_outputs, std::vector* node_map, Graph* output) { + VLOG(3) << "Stack: " << NodesToString(stack); std::vector visited(graph.num_node_ids(), false); while (!stack.empty()) { Node* n = stack.back(); stack.pop_back(); - VLOG(3) << "Copying node " << n->name(); + VLOG(5) << "Copying node " << n->name(); if (visited[n->id()]) continue; visited[n->id()] = true; for (const Edge* e : n->in_edges()) { Node* src = e->src(); - if (frame.nodes.find(src) == frame.nodes.end()) { + if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { // We traversed out of the loop frame, without encountering a cut node. - return errors::Internal("Graph traversal of loop frame ", frame.name, + return errors::Internal("Graph traversal of loop frame ", frame->name, " escaped frame at ", src->name(), " without encountering an argument node."); } @@ -119,27 +137,31 @@ Status CopySubgraph(const Graph& graph, const Frame& frame, return Status::OK(); } -Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) { +xla::StatusOr AddNode(const NodeDef& node_def, Graph* graph) { + Status status; + Node* inserted_node = graph->AddNode(node_def, &status); + if (!status.ok()) { + return status; + } + return inserted_node; +} + +xla::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp); + NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); builder.Attr("index", index); TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); - Status status; - *arg_node = graph->AddNode(arg_def, &status); - return status; + return AddNode(arg_def, graph); } -Status BuildRetvalNode(Graph* graph, DataType type, int index, - Node** retval_node) { +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat("_Retval", index)); + ret_def.set_name(strings::StrCat(kRetValOp, index)); AddNodeAttr("T", type, &ret_def); AddNodeAttr("index", index, &ret_def); - Status status; - *retval_node = graph->AddNode(ret_def, &status); - return status; + return AddNode(ret_def, graph); } // Builds a graph for the loop condition. @@ -157,9 +179,8 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, for (int i = 0; i < frame->args.size(); ++i) { const Arg& arg = frame->args[i]; - Node* arg_node; - TF_RETURN_IF_ERROR( - BuildArgNode(output, arg.enter->input_type(0), i, &arg_node)); + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(output, arg.enter->input_type(0), i)); if (arg.is_loop_invariant) { node_map[arg.enter->id()] = arg_node; } else { @@ -169,16 +190,14 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, // Build a Retval node for the loop condition. The LoopCond nodes are always // boolean because of the type constraints on the LoopCond op. - TF_RETURN_IF_ERROR( - BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()])); + TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], + BuildRetvalNode(output, DT_BOOL, 0)); // Performs a reverse DFS, copying nodes and edges to the output graph. // The _Arg and _Retval nodes were added unconditionally above, so we are // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond}, - squash_src_outputs, &node_map, output)); - - return Status::OK(); + return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, + &node_map, output); } // Builds a graph for the loop body. @@ -202,8 +221,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, DataType dtype = arg.enter->input_type(0); arg_types->push_back(dtype); - Node* arg_node; - TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node)); + + TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); if (dtype == DT_RESOURCE) { // The convention of the XLA bridge is that resource variable arguments @@ -213,8 +232,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, TF_RET_CHECK(arg.is_loop_invariant); node_map[arg.enter->id()] = arg_node; } else { - Node* retval_node; - TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node)); + TF_ASSIGN_OR_RETURN(Node * retval_node, + BuildRetvalNode(output, dtype, i)); if (arg.is_loop_invariant) { // Argument is loop-invariant. Forward it from the Arg to the Retval. @@ -237,7 +256,7 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, // Performs a reverse DFS, copying nodes and edges to the output graph. // The _Arg and _Retval nodes were added unconditionally above, so we are // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations), + TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), squash_src_outputs, &node_map, output)); return Status::OK(); @@ -396,10 +415,6 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, arg.exit = edge->dst(); } } - if (arg.exit == nullptr) { - return errors::InvalidArgument("Missing Exit successor to ", - arg.switch_node->name()); - } } } @@ -450,12 +465,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } builder.Input(inputs); TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); - - Status status; - Node* while_node = graph->AddNode(while_def, &status); - if (!status.ok()) { - return status; - } + TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph)); // Copies edges to the Enter nodes and from the Exit nodes onto the While. for (int i = 0; i < frame->args.size(); ++i) { @@ -469,16 +479,21 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } if (!arg.is_loop_invariant) { - std::vector edges(arg.exit->out_edges().begin(), - arg.exit->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - int src_output = - dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; - graph->AddEdge(while_node, src_output, dst, dst_input); + // Add output edges if the output of the loop is consumed. + if (arg.exit != nullptr) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + if (dst_input == Graph::kControlSlot) { + graph->AddControlEdge(while_node, dst); + } else { + graph->AddEdge(while_node, i, dst, dst_input); + } + } } } } @@ -488,6 +503,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, for (Node* node : frame->nodes) { graph->RemoveNode(node); } + frame->nodes.clear(); frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " @@ -496,13 +512,863 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, return Status::OK(); } +class FunctionalizeCond { + public: + // Identifies the connected parts of the tf.Cond. + struct ClusterHandle { + explicit ClusterHandle(int representative = -1) + : representative(representative) {} + + bool operator==(const ClusterHandle& other) const { + return representative == other.representative; + } + + bool operator!=(const ClusterHandle& other) const { + return !(*this == other); + } + + bool operator<(const ClusterHandle& other) const { + return representative < other.representative; + } + + bool operator>(const ClusterHandle& other) const { + return representative > other.representative; + } + + string ToString() const { + return strings::StrCat("Cluster_", representative); + } + + // Vector of UnionFind indexable by ClusterHandle and Node*. + struct Vector { + explicit Vector(size_t size) : clusters(size) {} + + UnionFind& at(const ClusterHandle& cluster) { + return clusters.at(cluster.representative); + } + + UnionFind& at(const Node* node) { + return clusters.at(node->id()); + } + + UnionFind& operator[](const Node* node) { + return clusters.at(node->id()); + } + + size_t size() const { return clusters.size(); } + + void resize(size_t count) { return clusters.resize(count); } + + private: + std::vector> clusters; + }; + + private: + int representative; + }; + + // Represents a node in the clustered graph consisting of switch_nodes, + // merge_nodes as well as the edges into and out of this node to other + // Clusters. Each Cluster corresponds to a ClusterHandle and has a + // corresponding representative. + struct Cluster { + std::unordered_set switch_nodes; + std::unordered_set merge_nodes; + std::unordered_set in_nodes; + std::unordered_set out_nodes; + + // A member of the ClusterHandle corresponding to this Cluster. + ClusterHandle representative; + bool visited = false; + }; + + // Represent the clustered graph as map from cluster representative to + // Cluster. + using ClusteredGraph = std::map; + + // The arguments and condition of a XlaIf. The arguments are ordered by node + // id in the original graph. + struct CondArgs { + struct CondCmp { + bool operator()(const Node* lhs, const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); + } + }; + Node* conditional = nullptr; + std::set args; + }; + + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + + private: + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) + : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} + + // Returns a vector of Merge nodes from the clustered graph where the nodes + // are sorted by the number of switch nodes minus number of merge nodes + // from a root of the clustered graph to the given Merge node, with ties + // broken by the representative of the Cluster. + std::vector> SortedMergeNodes(); + + // Returns whether the graph has no conditionals. + bool NoConditionals() const { return merge_nodes_.empty(); } + + // Construct the clustered graph by creating nodes for each cluster and the + // connections between the clusters. Switch and Merge nodes partition + // clusters, so iterate over those. Note: a Cluster may have neither a + // Merge or Switch but will have an in/out edge from a Cluster that has. + void CreateClusters(); + + // Creates the clustered graph by identifying all the edges between different + // clusters and collecting all switch and merge nodes that correspond to a + // cluster. + void CreateClusteredGraph(); + + // If `from` and `to` correspond to different clusters, then merge the nodes + // in the clustered graph corresponding to `from` and `to`. + // + // If `remove_from_graph` is specified then the `from` node is also removed + // from the clustered graph post contracting the edge. + void ContractEdge(Cluster* from, Cluster* to, bool remove_from_graph = false); + + // Converts a Merge node to a XlaIf. This encapsulates the process of + // extracting the bodies needed for the then and else branch, creates a XlaIf + // node, removing the nodes of the branches from the graph and replacing the + // merge node with a XlaIf. + Status ConvertMergeToXlaIf(Cluster* merge_cluster); + + // Removes a Switch cluster feeding directly into a Merge cluster by removing + // the Switch and Merge nodes and collapsing into a single cluster. + Status RemoveTrivialMerge(Cluster* merge_cluster); + + // Returns the switch cluster corresponding to the merge node. This function + // only returns the switch cluster in the simple case where we have a switch + // node is the entry of a diamond corresponding to a conditional: + // + // Switch + // / \ + // Branch Branch + // \ / + // merge_cluster + // + // Note: either of the branches may be empty. The case where both branches are + // empty is handled by RemoveTrivialMerge. + gtl::optional GetSwitchCluster(const Cluster& merge_cluster); + + // Determines the arguments needed as input to the Merge cluster originating + // from the Switch cluster. + xla::StatusOr DetermineCondArgs(const Cluster& merge_cluster, + const Cluster& switch_cluster); + + // Builds a XlaIfOp to replace the Merge node with. + xla::StatusOr BuildAndAddXlaIfOp(const CondArgs& cond_args, + const Cluster& merge_cluster, + const std::vector& outputs); + + // Extracts a function body corresponding to the given input edge of the merge + // node. + Status ExtractBody(const CondArgs& cond_args, const Cluster& merge_cluster, + const std::vector& outputs, int input_edge, + Graph* body); + + // Adds all the input edges to `if_node` corresponding to the arguments. + Status AddInputEdges(const CondArgs& cond_args, Node* if_node); + + // Adds all output edges from the `if_node`. + Status AddOutputEdges(const std::vector& outputs, Node* if_node); + + // Removes all nodes from the graph that are part of cluster. + void RemoveClusterNodes(Cluster* cluster); + + // Removes all argument nodes that are unused. + template + void RemoveUnusedArgs(const T& args); + + // Removes all Merge nodes in merge_cluster. + void RemoveMergeNodes(Cluster* merge_cluster); + + // Returns the representative member of the corresponding cluster. + ClusterHandle Representative(const Node* node) { + return clusters_.at(node).Get(); + } + + ClusteredGraph clustered_graph_; + ClusterHandle::Vector clusters_; + std::unordered_set merge_nodes_; + std::unordered_set switch_nodes_; + FunctionLibraryDefinition* library_; + Graph* graph_; +}; + +std::ostream& operator<<(std::ostream& os, + const FunctionalizeCond::ClusterHandle& c) { + os << c.ToString(); + return os; +} + +// Returns a dot representation of the clustered graph showing the connections +// between the nodes and the nodes in each cluster. +string DebugString(const Graph& graph, + FunctionalizeCond::ClusterHandle::Vector* clusters) { + string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; + std::map subgraphs; + for (Node* n : graph.nodes()) { + if (n->IsOp()) { + strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), + " [label=\"", n->name(), "\"];\n"); + } + } + for (auto kv : subgraphs) { + strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", + "style=filled; color=lightgrey;", "label = \"", + kv.first.ToString(), "\";\n", kv.second, "}\n"); + } + for (Node* n : graph.nodes()) { + if (!n->IsOp()) { + continue; + } + for (Node* in : n->in_nodes()) { + if (in->IsOp()) { + strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); + } + } + } + return strings::StrCat(ret, "}"); +} + +string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { + string ret = "digraph {\ncompound=true;labeljust=\"r\";\n"; + auto name = [](const FunctionalizeCond::Cluster& cluster) { + return cluster.representative.ToString(); + }; + for (auto kv : clustered_graph) { + strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), + " (", kv.second.switch_nodes.size(), ", ", + kv.second.merge_nodes.size(), ")\"];\n"); + } + for (auto kv : clustered_graph) { + for (auto in : kv.second.in_nodes) { + strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); + } + } + return strings::StrCat(ret, "}"); +} + +bool IsDeadSwitch(const Node* node) { + for (const Edge* e : node->out_edges()) { + const Node* dst = e->dst(); + if (!dst->IsIdentity()) { + return false; + } + for (const Edge* ee : dst->out_edges()) { + if (!ee->IsControlEdge() || !ee->dst()->IsSink()) { + return false; + } + } + } + return true; +} + +void FunctionalizeCond::CreateClusters() { + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + if (IsSwitch(node)) { + switch_nodes_.insert(node); + } else if (IsMerge(node)) { + merge_nodes_.insert(node); + } + ClusterHandle& cluster = clusters_.at(node).Get(); + cluster = ClusterHandle(node->id()); + } + + // If there are no Merge nodes, then terminate. + if (merge_nodes_.empty()) { + return; + } + + // Remove all dead Switch nodes. + RemoveUnusedArgs(switch_nodes_); + + // All parent_'s are still nullptr so clusters_ may still be resized. Resize + // conservatively assuming all merge nodes become XlaIf nodes. + clusters_.resize(clusters_.size() + merge_nodes_.size()); + + // Merge a cluster with its input, unless the input is a Switch node or + // the node is a Merge node. + for (const Node* node : graph_->nodes()) { + if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) { + continue; + } + for (const Node* in : node->in_nodes()) { + if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) { + clusters_.at(node).Merge(&clusters_.at(in)); + } + } + } +} + +void FunctionalizeCond::ContractEdge(Cluster* from, Cluster* to, + bool remove_from_graph) { + VLOG(3) << "ContractEdge from = " << from->representative + << " to = " << to->representative; + if (from->representative == to->representative) { + return; + } + to->merge_nodes.insert(from->merge_nodes.begin(), from->merge_nodes.end()); + from->merge_nodes.clear(); + to->switch_nodes.insert(from->switch_nodes.begin(), from->switch_nodes.end()); + from->switch_nodes.clear(); + + for (Cluster* from_out : from->out_nodes) { + from_out->in_nodes.erase(from); + if (from_out->representative != to->representative) { + from_out->in_nodes.insert(to); + to->out_nodes.insert(from_out); + } + } + from->out_nodes.clear(); + + for (Cluster* from_in : from->in_nodes) { + from_in->out_nodes.erase(from); + if (from_in->representative != to->representative) { + from_in->out_nodes.insert(to); + to->in_nodes.insert(from_in); + } + } + from->in_nodes.clear(); + + to->in_nodes.erase(from); + to->out_nodes.erase(from); + clusters_.at(to->representative).Merge(&clusters_.at(from->representative)); + from->visited = true; + + if (remove_from_graph) { + clustered_graph_.erase(from->representative); + } +} + +void FunctionalizeCond::CreateClusteredGraph() { + auto update_cluster_for_node = [this](Node* node) -> Cluster& { + ClusterHandle repr = Representative(node); + Cluster& cluster_node = clustered_graph_[repr]; + cluster_node.representative = repr; + for (const Node* in : node->in_nodes()) { + ClusterHandle other_repr = Representative(in); + // Skip source, sink and internal edges. + if (!in->IsOp() || other_repr == repr) { + continue; + } + Cluster& cluster_node_in = clustered_graph_[other_repr]; + cluster_node.in_nodes.insert(&cluster_node_in); + cluster_node_in.out_nodes.insert(&cluster_node); + cluster_node_in.representative = other_repr; + } + for (const Node* out : node->out_nodes()) { + ClusterHandle other_repr = Representative(out); + // Skip source, sink and internal edges. + if (!out->IsOp() || other_repr == repr) { + continue; + } + Cluster& cluster_node_out = clustered_graph_[other_repr]; + cluster_node.out_nodes.insert(&cluster_node_out); + cluster_node_out.in_nodes.insert(&cluster_node); + cluster_node_out.representative = other_repr; + } + return cluster_node; + }; + for (Node* node : switch_nodes_) { + update_cluster_for_node(node).switch_nodes.insert(node); + } + for (Node* node : merge_nodes_) { + update_cluster_for_node(node).merge_nodes.insert(node); + } + + // Merge Switch nodes with common predicate. + std::unordered_map> predicate_to_switch; + for (Node* node : switch_nodes_) { + Node* tmp; + TF_CHECK_OK(node->input_node(1, &tmp)); + predicate_to_switch[tmp].push_back(node); + } + for (auto kv : predicate_to_switch) { + Cluster& first = clustered_graph_.at(Representative(kv.second.front())); + for (Node* switch_node : kv.second) { + ClusterHandle handle = Representative(switch_node); + Cluster& cluster = clustered_graph_.at(handle); + ContractEdge(&cluster, &first, /*remove_from_graph=*/true); + } + } + + // Merge Merge nodes with common input together. + for (Node* node : merge_nodes_) { + Cluster& cluster = clustered_graph_.at(Representative(node)); + for (const Node* in : node->in_nodes()) { + if (!in->IsOp()) { + continue; + } + Cluster& cluster_node_in = clustered_graph_.at(Representative(in)); + // ContractEdge can modify out_nodes of cluster_node_in, so traverse + // over out_nodes assuming it does. + for (auto it = cluster_node_in.out_nodes.begin(); + it != cluster_node_in.out_nodes.end();) { + if (!(*it)->merge_nodes.empty()) { + ContractEdge(*it++, &cluster, /*remove_from_graph=*/true); + } else { + ++it; + } + } + } + } + + VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); + VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); +} + +gtl::optional FunctionalizeCond::GetSwitchCluster( + const Cluster& merge_cluster) { + VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative; + gtl::optional switch_cluster; + if (merge_cluster.in_nodes.size() > 2) { + return gtl::nullopt; + } + for (Cluster* in : merge_cluster.in_nodes) { + Cluster* cluster = in; + if (in->switch_nodes.empty()) { + if (in->in_nodes.size() != 1) { + return gtl::nullopt; + } + // There is only a single `in` cluster. + cluster = *in->in_nodes.begin(); + } + if (cluster->switch_nodes.empty()) { + return gtl::nullopt; + } + + if (switch_cluster.has_value() && *switch_cluster != cluster) { + return gtl::nullopt; + } else { + switch_cluster = cluster; + } + } + return switch_cluster; +} + +xla::StatusOr FunctionalizeCond::DetermineCondArgs( + const Cluster& merge_cluster, const Cluster& switch_cluster) { + VLOG(2) << "DetermineCondArgs for " << merge_cluster.representative + << " with switch cluster " << switch_cluster.representative; + CondArgs ret; + auto feeds_into_branch_cluster = [&](Node* switch_cluster) { + for (Node* out : switch_cluster->out_nodes()) { + ClusterHandle repr = Representative(out); + if (repr == merge_cluster.representative) { + return true; + } + for (Cluster* in : merge_cluster.in_nodes) { + if (repr == in->representative) { + return true; + } + } + } + return false; + }; + for (Node* switch_cluster_node : switch_cluster.switch_nodes) { + if (!feeds_into_branch_cluster(switch_cluster_node)) { + continue; + } + + Node* tmp; + TF_RETURN_IF_ERROR(switch_cluster_node->input_node(1, &tmp)); + if (ret.conditional == nullptr) { + ret.conditional = tmp; + } else if (ret.conditional != tmp) { + return errors::Unimplemented( + "Switch statements with different conditionals cannot be " + "converted into functional conditional."); + } + ret.args.insert(switch_cluster_node); + } + return ret; +} + +xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( + const CondArgs& cond_args, const Cluster& merge_cluster, + const std::vector& outputs) { + VLOG(2) << "Build if op for " << NodesToString(merge_cluster.merge_nodes) + << " with input " << NodesToString(cond_args.args); + + NodeDef if_def; + // Create a new If node using the name of the merge node. + NodeDefBuilder builder( + strings::StrCat((*merge_cluster.merge_nodes.begin())->name(), "_If"), + "XlaIf"); + string branch[] = {"else_branch", "then_branch"}; + for (int i = 0; i < 2; ++i) { + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + + NameAttrList body_name; + body_name.set_name( + strings::StrCat("_functionalize_if_", branch[i], "_", id)); + auto body = xla::MakeUnique(graph_->op_registry()); + TF_RETURN_IF_ERROR( + ExtractBody(cond_args, merge_cluster, outputs, i, body.get())); + VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); + TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); + builder.Attr(branch[i], body_name); + } + + // Build input type. + std::vector inputs; + DataTypeVector in_arg_types; + for (const Node* arg : cond_args.args) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + } + } + builder.Attr("Tin", in_arg_types); + + // Build output type. + DataTypeVector out_type; + for (const Node* merge : merge_cluster.merge_nodes) { + DataType dtype = merge->output_type(0); + out_type.push_back(dtype); + } + builder.Attr("Tout", out_type); + + builder.Attr("Tcond", DT_BOOL); + builder.Device(cond_args.conditional->assigned_device_name()); + // Conditional should be the first input ... + builder.Input(NodeDefBuilder::NodeOut(cond_args.conditional->name(), 0, + cond_args.conditional->output_type(0))); + // ... followed by the other inputs. + builder.Input(inputs); + + TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); + TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_)); + return if_node; +} + +void FunctionalizeCond::RemoveClusterNodes(Cluster* cluster) { + VLOG(3) << "RemoveClusterNodes for " << cluster->representative; + ClusterHandle repr = cluster->representative; + std::deque to_delete; + for (Node* node : graph_->nodes()) { + if (Representative(node) == repr) { + to_delete.push_back(node); + } + } + for (Node* n : to_delete) { + graph_->RemoveNode(n); + } +} + +template +void FunctionalizeCond::RemoveUnusedArgs(const T& args) { + VLOG(2) << "RemoveUnusedArgs among: " << NodesToString(args); + + std::deque to_delete; + for (Node* arg : args) { + if (IsDeadSwitch(arg)) { + to_delete.push_back(arg); + for (Node* n : arg->out_nodes()) { + to_delete.push_back(n); + } + } + } + for (Node* n : to_delete) { + switch_nodes_.erase(n); + auto it = clustered_graph_.find(Representative(n)); + if (it != clustered_graph_.end()) { + it->second.switch_nodes.erase(n); + } + graph_->RemoveNode(n); + } +} + +Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, + const Cluster& merge_cluster, + const std::vector& outputs, + int input_edge, Graph* body) { + VLOG(2) << "ExtractBody for " << merge_cluster.representative + << " along edge " << input_edge; + std::vector squash_src_outputs(graph_->num_node_ids(), false); + std::vector node_map(graph_->num_node_ids(), nullptr); + int arg_count = 0; + for (const auto* arg : cond_args.args) { + DataType dtype = arg->input_type(0); + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(body, dtype, arg_count++)); + node_map.at(arg->id()) = arg_node; + squash_src_outputs.at(arg->id()) = true; + } + + std::vector stack; + stack.reserve(outputs.size()); + for (int j = 0; j < outputs.size(); ++j) { + Node* node = outputs[j]; + TF_ASSIGN_OR_RETURN(node_map.at(node->id()), + BuildRetvalNode(body, node->output_type(0), + /*index=*/j)); + const Edge* in_edge; + TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge)); + Node* in = in_edge->src(); + if (node_map.at(in->id()) == nullptr) { + node_map.at(in->id()) = body->CopyNode(in); + } + + if (cond_args.args.find(in) == cond_args.args.end()) { + body->AddEdge(node_map.at(in->id()), in_edge->src_output(), + node_map.at(node->id()), 0); + } else { + body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0); + // Don't include input nodes that are already just returned in stack. + continue; + } + stack.push_back(in); + } + + return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map, + body); +} + +Status FunctionalizeCond::AddInputEdges(const CondArgs& cond_args, + Node* if_node) { + VLOG(3) << "AddInputEdges for " << if_node->name(); + int i = 0; + graph_->AddEdge(cond_args.conditional, 0, if_node, i++); + for (const Node* arg : cond_args.args) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph_->AddControlEdge(in_edge->src(), if_node); + } else { + graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, i++); + } + } + return Status::OK(); +} + +Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, + Node* if_node) { + VLOG(3) << "AddOutputEdges for " << if_node->name(); + for (int i = 0; i < outputs.size(); ++i) { + Node* node = outputs[i]; + std::vector edges(node->out_edges().begin(), + node->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + + if (edge->src_output() > 0) { + return errors::Unimplemented("Output of index (", edge->src_output(), + ") of merge node ", node->name()); + } + graph_->RemoveEdge(edge); + + int src_output = + dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph_->AddEdge(if_node, src_output, dst, dst_input); + } + } + return Status::OK(); +} + +void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { + VLOG(3) << "RemoveMergeNodes for " << merge_cluster->representative; + // Remove all merge nodes now dead post extraction of If. + for (auto it = merge_cluster->merge_nodes.begin(); + it != merge_cluster->merge_nodes.end();) { + Node* node = *it; + graph_->RemoveNode(node); + merge_cluster->merge_nodes.erase(*it++); + } +} + +Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) { + Cluster* switch_cluster = *merge_cluster->in_nodes.begin(); + if (switch_cluster->switch_nodes.empty()) { + return errors::FailedPrecondition( + "Not a trivial merge: no Switch node feeding into Merge node"); + } + + for (auto it = merge_cluster->merge_nodes.begin(); + it != merge_cluster->merge_nodes.end();) { + // We have the following structure: + // Op -> Switch -> Merge -> Consumer + // and we want to transform it to: + // Op -> Consumer + Node* merge_node = *it; + Node* switch_node; + const Edge* in = nullptr; + TF_RETURN_IF_ERROR(merge_node->input_node(0, &switch_node)); + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in)); + for (auto out : merge_node->out_edges()) { + int src_output = out->dst_input() == Graph::kControlSlot + ? Graph::kControlSlot + : in->src_output(); + graph_->AddEdge(in->src(), src_output, out->dst(), out->dst_input()); + } + graph_->RemoveNode(*it++); + } + RemoveUnusedArgs(switch_cluster->switch_nodes); + + return Status::OK(); +} + +Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { + VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative; + gtl::optional switch_cluster = GetSwitchCluster(*merge_cluster); + if (!switch_cluster.has_value()) { + return errors::FailedPrecondition( + "Merge cluster was not part of a simple conditional in the clustered " + "graph. Graph nodes in merge cluster ", + NodesToString(merge_cluster->merge_nodes)); + } + TF_ASSIGN_OR_RETURN(auto cond_args, + DetermineCondArgs(*merge_cluster, **switch_cluster)); + + // Sort the outputs by ID to produce more stable output. + std::vector outputs(merge_cluster->merge_nodes.begin(), + merge_cluster->merge_nodes.end()); + std::sort(outputs.begin(), outputs.end(), CondArgs::CondCmp()); + + // Extract bodies and builds a If operator. + TF_ASSIGN_OR_RETURN(Node * if_node, + BuildAndAddXlaIfOp(cond_args, *merge_cluster, outputs)); + TF_RETURN_IF_ERROR(AddInputEdges(cond_args, if_node)); + TF_RETURN_IF_ERROR(AddOutputEdges(outputs, if_node)); + + // Remove the old nodes from the graph_ and contract the edges of the + // clustered graph. + for (auto in : merge_cluster->in_nodes) { + if (in != *switch_cluster) { + RemoveClusterNodes(in); + } + } + RemoveMergeNodes(merge_cluster); + RemoveUnusedArgs(cond_args.args); + auto in_nodes = merge_cluster->in_nodes; + for (auto it = in_nodes.begin(); it != in_nodes.end();) { + ContractEdge(*it++, merge_cluster); + } + ContractEdge(*switch_cluster, merge_cluster); + clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative); + + return Status::OK(); +} + +std::vector> +FunctionalizeCond::SortedMergeNodes() { + VLOG(2) << "ProcessClusteredGraph"; + std::stack> stack; + for (auto& c : clustered_graph_) { + if (c.second.in_nodes.empty()) { + stack.push({0, &c.second}); + } + } + + // Perform a depth-first traversal of the clustered graph computing the + // switch-merge depth. + std::vector> queue; + std::unordered_set visited; + while (!stack.empty()) { + Cluster* n = stack.top().second; + size_t depth = stack.top().first; + stack.pop(); + + auto inserted = visited.insert(n); + if (!inserted.second) { + continue; + } + + size_t new_depth = depth; + if (!n->merge_nodes.empty()) { + queue.emplace_back(depth, n); + --new_depth; + } + if (!n->switch_nodes.empty()) { + ++new_depth; + } + for (Cluster* e : n->out_nodes) { + stack.emplace(new_depth, e); + } + } + + // Sort in reverse order of switch-merge depth with ties broken by the + // ClusterHandle. + std::sort(queue.begin(), queue.end(), + [](const std::pair& lhs, + const std::pair& rhs) { + return std::tie(lhs.first, lhs.second->representative) > + std::tie(rhs.first, rhs.second->representative); + }); + + return queue; +} + +Status FunctionalizeCond::Functionalize(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "FunctionalizeCond::Functionalize"; + FunctionalizeCond fc(graph, library); + fc.CreateClusters(); + if (fc.NoConditionals()) { + return Status::OK(); + } + fc.CreateClusteredGraph(); + + auto queue = fc.SortedMergeNodes(); + for (auto it = queue.begin(); it != queue.end();) { + Cluster* merge_cluster = (*it).second; + ++it; + if (merge_cluster->in_nodes.size() == 1) { + TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster)); + } else { + TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster)); + } + + // Contract newly Merge free merge_cluster with incoming nodes without + // Switch or Merge nodes. + std::vector in_nodes(merge_cluster->in_nodes.begin(), + merge_cluster->in_nodes.end()); + for (auto in : in_nodes) { + if (in->merge_nodes.empty() && in->switch_nodes.empty()) { + fc.ContractEdge(in, merge_cluster); + } + } + } + + if (!fc.switch_nodes_.empty()) { + return errors::Internal( + "Failed to functionalize control flow with Switch nodes remaining: ", + NodesToString(fc.switch_nodes_)); + } + return Status::OK(); +} + } // namespace // Transformation that converts Tensorflow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { - VLOG(2) << "FunctionalizeControlFlow: " + VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph); // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this @@ -577,6 +1443,13 @@ Status FunctionalizeControlFlow(Graph* graph, } } + // FunctionalizeControlFlow is invoked for every function, so the loops's + // bodies and conditionals that were extracted into functions will be handled + // in successive invocations. + TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); + + VLOG(2) << "FunctionalizeControlFlow (final): " + << dump_graph::DumpGraphToFile("functionalize_final", *graph); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 1535dc80b0ccdba38c57b534ed7473fc8632e33f..4d4ee3054c2914bb614bf75f7a51be8f6292683e 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -23,7 +23,6 @@ namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While // operators, suitable for XLA compilation. -// TODO(b/36470387): add support for conditionals. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 914c8999a6f13f5f2dc4e3cecc38c91afd432131..01d2b282751f387cfa9c8887cdeb48090c96bff4 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -35,6 +36,134 @@ limitations under the License. namespace tensorflow { namespace { +// Returns the names of the "then" and "else" functions for the XlaIf node in a +// graph. +Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn, + NameAttrList* else_fn) { + for (const NodeDef& node : graph.node()) { + if (node.op() == "XlaIf") { + const NameAttrList* result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); + *then_fn = *result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result)); + *else_fn = *result; + return Status::OK(); + } + } + return errors::NotFound("No XlaIf node found in graph"); +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = array_ops.placeholder(dtypes.int32) +// z = control_flow_ops.cond( +// math_ops.less(y, x), lambda: math_ops.multiply(y, 17), +// lambda: math_ops.add(x, 23)) +TEST(FunctionalizeControlFlow, Conditional) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less); + + auto identity_t = + ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true); + auto seventeen = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity_t), 17); + auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less); + auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true, + seventeen); + + auto identity_f = + ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false); + auto twenty_three = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity_f), 23); + auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); + auto add = ops::Add(scope.WithOpName("cond/false/add"), + switch_3.output_false, twenty_three); + + auto merge = ops::Merge(scope.WithOpName("cond/Merge"), + std::initializer_list{add, mul}); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = ops::XlaIf(scope.WithOpName("cond/Merge_If"), less, + std::initializer_list{less, y, x}, then_fn, + else_fn, {DT_INT32}); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // then body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // else body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Returns the names of the "cond" and "body" functions for the While node // in a graph. Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, @@ -168,6 +297,108 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// Tests functionalizing OneLoopVar where the loop value is not used post the +// loop. +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) +TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Graph: // x = array_ops.placeholder(dtypes.int32) // y = array_ops.placeholder(dtypes.int32) diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..8062f0c03ca60e88bd5c021092dceb105232219f --- /dev/null +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/graph_compiler.h" + +#include +#include +#include +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +namespace { +Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, + const std::vector& expressions, + std::vector* args) { + auto builder = ctx->builder(); + std::vector compile_time_constant_flags(expressions.size()); + + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(*graph, &compile_time_constant_flags)); + + args->resize(expressions.size()); + for (int i = 0; i < args->size(); ++i) { + XlaCompiler::Argument& arg = (*args)[i]; + arg.type = ctx->input_type(i); + + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + + if (arg.type == DT_RESOURCE) { + return errors::InvalidArgument( + "Resource as function argument is not yet implemented."); + } else if (expressions[i]->has_constant_value()) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = expressions[i]->constant_value(); + } else if (compile_time_constant_flags[i]) { + arg.kind = XlaCompiler::Argument::kConstant; + TF_RET_CHECK(expressions[i]->resource() == nullptr) + << "Input with resource is not yet implemented."; + TF_ASSIGN_OR_RETURN(auto literal, + builder->ComputeConstant(expressions[i]->handle())); + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + } + return Status::OK(); +} +} // namespace +Status GraphCompiler::Compile() { + // Maintain a mapping from node id to node outputs. + using NodeOutputs = std::vector; + std::vector output_registry(graph_->num_node_ids()); + auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] { + for (const NodeOutputs& outputs : output_registry) { + for (const TensorValue& value : outputs) { + CHECK(!value.is_ref()); + delete value.tensor; + } + } + }); + + // XLA requires determinism, generate a stable ordering from DFS. + std::vector topo_sorted_nodes; + GetReversePostOrder(*graph_, &topo_sorted_nodes, + /*stable_comparator=*/NodeComparatorName()); + + OpKernelContext::Params params; + PartiallySetupParams(¶ms); + + for (Node* n : topo_sorted_nodes) { + OpKernel* op_kernel_raw = nullptr; + Status s = flib_->CreateKernel(n->def(), &op_kernel_raw); + // Transfer ownership of the kernel to a local smart pointer. + std::unique_ptr op_kernel(op_kernel_raw); + + if (!s.ok()) { + s = AttachDef(s, *n); + LOG(ERROR) << "Executor failed to create kernel. " << s; + return s; + } + + TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch()) + << "Not supported node: " << n->DebugString(); + params.op_kernel = op_kernel.get(); + gtl::InlinedVector output_attr(n->num_outputs()); + params.output_attr_array = output_attr.data(); + + // tensor_inputs_ is a buffer reused across graph traversal. We clean up and + // reinitialize the buffer before we visit a new node. + tensor_inputs_.clear(); + tensor_inputs_.resize(n->num_inputs()); + + // Set up inputs from outputs of previous nodes. + for (auto* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + Node* src = e->src(); + TF_RET_CHECK(src->id() < output_registry.size()); + const NodeOutputs& src_outputs = output_registry[src->id()]; + + tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()]; + } + + OpKernelContext op_context(¶ms, n->num_outputs()); + if (IsFunctional(n)) { + TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); + } else { + device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); + Status s = op_context.status(); + TF_RETURN_IF_ERROR(s); + } + + // Set up outputs. Also check if outputs from the previous computation is + // valid. + NodeOutputs& outputs = output_registry[n->id()]; + outputs.resize(n->num_outputs()); + for (int o = 0; o < n->num_outputs(); ++o) { + outputs[o] = op_context.release_output(o); + if (*op_context.is_output_dead() || outputs[o].tensor == nullptr) { + return errors::Internal("Missing xla_context ", o, "-th output from ", + (*op_context.is_output_dead() ? "(dead)" : ""), + SummarizeNode(*n)); + } + } + } + return Status::OK(); +} + +bool GraphCompiler::IsFunctional(Node* n) { + return n->type_string() == FunctionLibraryDefinition::kGradientOp || + (flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) != + nullptr); +} + +Status GraphCompiler::CompileFunctionalNode(Node* n, + OpKernelContext* op_context) { + TF_RET_CHECK(IsFunctional(n)); + // For functional nodes, compile them using compiler from the context and call + // into the functions. + XlaOpKernelContext xla_op_context(op_context); + + XlaCompiler* compiler = xla_op_context.compiler(); + + NameAttrList func; + if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) { + func.set_name(n->def().op()); + } else { + func.set_name(FunctionLibraryDefinition::kGradientOp); + } + *func.mutable_attr() = n->def().attr(); + + std::vector expressions; + + for (auto tensor : tensor_inputs_) { + auto expression = + reinterpret_cast(tensor->tensor_data().data()); + expressions.push_back(expression); + } + + // Prepare the arguments and compile the function. + std::vector arguments; + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody)); + + auto graph = compiler->GetGraph(fbody); + + TF_RETURN_IF_ERROR( + PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + + XlaCompiler::CompilationResult result; + + TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), + func, arguments, &result)); + + TF_RET_CHECK(arguments.size() == expressions.size()); + + std::vector handles; + for (int64 i = 0; i < expressions.size(); ++i) { + if (arguments[i].kind == XlaCompiler::Argument::kConstant) { + continue; + } + handles.push_back(expressions[i]->handle()); + } + + XlaContext& context = XlaContext::Get(op_context); + auto* b = context.builder(); + + auto output_handle = b->Call(*result.computation, handles); + // The output handle of `Call` computation is a tuple type. Unzip it so + // that it can fit into future computations. + for (int64 i = 0; i < n->num_outputs(); ++i) { + if (result.outputs[i].is_constant) { + xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); + } else { + xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + } + } + return b->first_error(); +} + +void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) { + params->device = device_; + params->inputs = &tensor_inputs_; + params->step_container = step_container_; + params->resource_manager = device_->resource_manager(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ba00160b6d78c1e55cc2e053cd5285344e0179fb --- /dev/null +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +// GraphCompiler compiles the graph in topological order in the current +// thread. It also resolves the nondeterminism in the graph by enforcing a +// total order on all inputs to a node. This abstraction helps us create the +// same XLA computation given two structurally equivalent TensorFlow graphs. +// If a function call is visited during the graph traversal, it is then +// compiled through the xla_context into a computation and a `Call` operation +// is inserted to call into that computation. +// +// Note: GraphCompiler was created to remove our dependency to TF Executor in +// the history. There are still some todos so that we can completely decouple +// from Executor. +// +// TODO(yunxing): Remove usage of XlaCompilationDevice. +// +// TODO(yunxing): Remove the hack that wraps XlaExpression within a tensor now +// that we don't use TF Executor to pass around a tensor. +// +// TODO(yunxing): Make XlaOpkernel not a subclass of OpKernel so that it can +// handle a XlaExpression directly instead of a Tensor. This may require our own +// op registration infrastructure instead of FunctionLibraryRuntime. +class GraphCompiler { + public: + GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, + Graph* graph, FunctionLibraryRuntime* flib, + ScopedStepContainer* step_container) + : xla_context_(xla_context), + device_(device), + graph_(graph), + flib_(flib), + step_container_(step_container) {} + + // Compiles the graph. The results are written in `xla_context` that is passed + // into the compiler. + Status Compile(); + + private: + // Partially sets params. This partially set params can be reused + // across multple nodes visit. + void PartiallySetupParams(OpKernelContext::Params* params); + + // Tests if a node is a functional node. A functional node represents a + // defined computation and should be compiled using `compiler_`. + bool IsFunctional(Node* n); + + // Compiles a functional node and writes result to OpkernelContext. A + // functional node represents a defined computation and should be compiled + // using `compiler_`. + Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); + + XlaContext* xla_context_; + XlaCompilationDevice* device_; + Graph* graph_; + FunctionLibraryRuntime* flib_; + ScopedStepContainer* step_container_; + // A buffer to hold tensor inputs to a node, this is reused across the graph + // traversal. + gtl::InlinedVector tensor_inputs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6a0c4fef7574319de2837f7a17867a41784d56b4..2b43e313eb42c288b891f97c0b6cd3cacdc77711 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -5,7 +5,6 @@ package( ) load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") tf_kernel_library( name = "xla_ops", @@ -83,6 +82,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", @@ -132,8 +132,6 @@ tf_kernel_library( name = "xla_cpu_only_ops", srcs = ["index_ops_cpu.cc"], deps = [ - ":gather_op_kernel_float_int32", - ":gather_op_kernel_float_int64", ":index_ops_kernel_argmax_float_1d", ":index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/tf2xla:common", @@ -149,39 +147,12 @@ tf_kernel_library( ], ) -cc_library( - name = "gather_op_kernel_float_int32", - srcs = ["gather_op_kernel_float_int32.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:bounds_check", - "//tensorflow/core/kernels:gather_functor_hdr", - "//third_party/eigen3", - ], - alwayslink = 1, -) - -cc_library( - name = "gather_op_kernel_float_int64", - srcs = ["gather_op_kernel_float_int64.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:bounds_check", - "//tensorflow/core/kernels:gather_functor_hdr", - "//third_party/eigen3", - ], - alwayslink = 1, -) - cc_library( name = "index_ops_kernel_argmax_float_1d", srcs = ["index_ops_kernel_argmax_float_1d.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -193,6 +164,7 @@ cc_library( srcs = ["index_ops_kernel_argmax_float_2d.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 16b778bca439b9236498945f132e8095baeb71c1..73ccc151c1d6bdf70105badd962903297f090abe 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -77,7 +77,13 @@ class BatchMatMulOp : public XlaOpKernel { xla::ComputationBuilder* builder = ctx->builder(); xla::ComputationDataHandle x_handle = ctx->Input(0); + if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) { + x_handle = builder->Conj(x_handle); + } xla::ComputationDataHandle y_handle = ctx->Input(1); + if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) { + y_handle = builder->Conj(y_handle); + } // Reshape input tensors into 3D tensors by flattening the batch // dimensions. This makes it easier to unroll the batch dimension. diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 58538b45137b26ed5aa296eb6c1077e88aea72b9..1de91924326464338352b1ac9edf77141f25ad35 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Native XLA implementations of simple unary Ops +// Native XLA implementations of simple binary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { @@ -50,6 +51,9 @@ XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); + // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); @@ -96,8 +100,17 @@ static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(LogicalAnd, b->LogicalAnd(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, b->LogicalOr(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); + +XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(RightShift, + (DataTypeIsUnsigned(ctx->input_type(0)) + ? b->ShiftRightLogical(lhs, rhs, extend_dimensions) + : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions))); + +XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); @@ -162,8 +175,12 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* b = ctx->builder(); - auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))), - XlaHelpers::FloatLiteral(b, input_type(0), tolerance_)); + auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); + auto abs_shape = b->GetShape(abs); + OP_REQUIRES_OK(ctx, abs_shape.status()); + auto abs_type = abs_shape.ValueOrDie()->element_type(); + auto result = b->Lt( + abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 2331520230176fce7646d89140851fe37aee5fda..43a6a747c6bcc441f33f276fde4a66f367d99731 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -40,6 +41,11 @@ class CastOp : public XlaOpKernel { output = input; } else if (dst_dtype_ == DT_BOOL) { output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); + } else if (xla::primitive_util::IsComplexType(src_type_) && + !xla::primitive_util::IsComplexType(dst_type_)) { + // As in cast_op.h, we replicate the numpy behavior of truncating the + // imaginary part. + output = builder->ConvertElementType(builder->Real(input), dst_type_); } else { output = builder->ConvertElementType(input, dst_type_); } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 0091b66d28ad62fcd5c0f3b09e90fed8347bb661..885f716afafca7ba23770e38f6693eed1ba50982 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -179,8 +179,10 @@ class ConvOp : public XlaOpKernel { xla::ConvolutionDimensionNumbers dims; std::vector window_strides; - dims.set_batch_dimension(GetTensorBatchDimIndex(num_dims(), data_format_)); - dims.set_feature_dimension(feature_dim); + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); for (int i = 0; i < num_spatial_dims_; ++i) { int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dims.add_spatial_dimensions(input_dim); @@ -285,8 +287,10 @@ class ConvBackpropInputOp : public XlaOpKernel { // comment at the top of conv_grad_ops.h for details. xla::ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(batch_dim); - dnums.set_feature_dimension(feature_dim); + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); // TF filter shape is [ H, W, ..., inC, outC ] // Transpose the input and output features for computing the gradient. @@ -419,8 +423,10 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Each spatial entry has size in_depth * batch // Swap n_dim and c_dim in the activations. - dnums.set_batch_dimension(c_dim); - dnums.set_feature_dimension(n_dim); + dnums.set_input_batch_dimension(c_dim); + dnums.set_output_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index dde7898015e73190c96fa6effddfd3fc892264ea..7349dcb987cd88c423570889c0502d1a0bd12c52 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -199,6 +199,7 @@ class DynamicStitchOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); +REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2c7d445600bf81873b0ed44fc8dccb23dcc902d6..e420f21ca33fe7de9b33f404ce04eae62d9c041e 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -30,7 +30,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( XlaOpKernelContext* context, const xla::ComputationDataHandle& input, const TensorShape& input_shape, const xla::ComputationDataHandle& indices, const TensorShape& indices_shape, int64 axis, DataType dtype, - xla::ComputationBuilder* builder) { + DataType index_type, xla::ComputationBuilder* builder) { // Although the indices Tensor is flattened into rank 1 during the lookup, // and each scalar entry is used as an index into the first dimension of the // input, the output is returned with shape: @@ -80,22 +80,23 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Specify the shape of the loop-carried Tensor tuple. xla::PrimitiveType ptype; TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); + xla::PrimitiveType idxtype; + TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype)); std::vector tuple_shapes( {// The iteration counter i is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), + xla::ShapeUtil::MakeShape(idxtype, {}), // The input array has shape input_shape. Loop invariant. xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), // The gather indices are reshaped to rank 1. Loop invariant. - xla::ShapeUtil::MakeShape(xla::S32, {num_indices}), + xla::ShapeUtil::MakeShape(idxtype, {num_indices}), // The output array is rank >= 3, and is updated on each loop iteration. xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); // Construct the initial values of the loop-carried Tensors. - auto init_i = builder->ConstantR0(0); - auto init_out = - builder->Broadcast(builder->ConstantLiteral(xla::Literal::Zero(ptype)), - loop_out_shape.dim_sizes()); + auto init_i = XlaHelpers::Zero(builder, index_type); + auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + loop_out_shape.dim_sizes()); // Flatten the indices into 1-D for ease of iteration. auto indices_1d = builder->Reshape(indices, {num_indices}); auto init = builder->Tuple({init_i, input, indices_1d, init_out}); @@ -105,7 +106,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( "GatherWhileCond"); condb.Lt(condb.GetTupleElement( condb.Parameter(0, tuple_shape, "GatherWhileTuple"), 0), - condb.ConstantR0(num_indices)); + XlaHelpers::IntegerLiteral(&condb, index_type, num_indices)); auto cond_status = condb.Build(); auto cond = cond_status.ConsumeValueOrDie(); @@ -127,7 +128,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Slice from the input array. auto index = bodyb.DynamicSlice(indices, bodyb.Reshape(i, {1}), {1}); auto start_indices = bodyb.Pad( - bodyb.Reshape(index, {1}), bodyb.ConstantR0(0), + bodyb.Reshape(index, {1}), XlaHelpers::Zero(&bodyb, index_type), xla::MakeEdgePaddingConfig( {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); auto slice_i = bodyb.Reshape( @@ -136,7 +137,8 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Construct the index into the R3+ output Tensor 0, ..., , 0, ... std::vector out_index_vals( - loop_out_shape.dims(), bodyb.ConstantR1({0})); + loop_out_shape.dims(), + bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1})); out_index_vals[input_shape_pre_axis.dims() + extra_dims] = bodyb.Reshape(i, {1}); auto out_index = bodyb.ConcatInDim(out_index_vals, 0); @@ -144,8 +146,8 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Update the output Tensor auto updated_output = bodyb.DynamicUpdateSlice(output, slice_i, out_index); - bodyb.Tuple({bodyb.Add(i, bodyb.ConstantR0(1)), input, indices, - updated_output}); + bodyb.Tuple({bodyb.Add(i, XlaHelpers::One(&bodyb, index_type)), input, + indices, updated_output}); } auto body_status = bodyb.Build(); auto body = body_status.ConsumeValueOrDie(); @@ -156,124 +158,6 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( return builder->Reshape(gather_output, out_shape.dim_sizes()); } -namespace { - -class GatherOpCustomCall : public XlaOpKernel { - public: - explicit GatherOpCustomCall(OpKernelConstruction* context) - : XlaOpKernel(context) {} - - void Compile(XlaOpKernelContext* context) override { - const TensorShape params_shape = context->InputShape(0); - const auto params_dims = params_shape.dims(); - const TensorShape indices_shape = context->InputShape(1); - OP_REQUIRES( - context, TensorShapeUtils::IsVectorOrHigher(params_shape), - errors::InvalidArgument("params must be at least 1 dimensional")); - - DataType index_type = input_type(1); - OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - // GatherV2 added an axis argument. We support both Gather and GatherV2 in - // this kernel by defaulting axis to 0 if there are 2 inputs. - int64 axis = 0; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), - errors::InvalidArgument("axis must be scalar")); - DataType axis_type = input_type(2); - OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, - errors::InvalidArgument("axis must be int32 or int64")); - - xla::Literal literal; - OP_REQUIRES_OK(context, context->ConstantInput(2, &literal)); - int64 axis_input = axis_type == DT_INT32 ? literal.Get({}) - : literal.Get({}); - axis = axis_input < 0 ? axis_input + params_dims : axis_input; - OP_REQUIRES(context, 0 <= axis && axis < params_dims, - errors::InvalidArgument("Expected axis in the range [", - -params_dims, ", ", params_dims, - "), but got ", axis_input)); - } - - // Check that we have enough index space. - const int64 limit = index_type == DT_INT32 - ? std::numeric_limits::max() - : std::numeric_limits::max(); - OP_REQUIRES(context, params_shape.dim_size(axis) <= limit, - errors::InvalidArgument( - "params.shape[", axis, "] too large for ", - DataTypeString(index_type), - " indexing: ", params_shape.dim_size(axis), " > ", limit)); - - // The result shape is params.shape[0:axis] + indices.shape + - // params.shape[axis + 1:]. - TensorShape result_shape; - int64 outer_size = 1; - int64 inner_size = 1; - for (int i = 0; i < axis; i++) { - result_shape.AddDim(params_shape.dim_size(i)); - outer_size *= params_shape.dim_size(i); - } - result_shape.AppendShape(indices_shape); - for (int i = axis + 1; i < params_dims; i++) { - result_shape.AddDim(params_shape.dim_size(i)); - inner_size *= params_shape.dim_size(i); - } - - XlaContext& tc = XlaContext::Get(context); - OP_REQUIRES( - context, tc.allow_cpu_custom_calls(), - errors::InvalidArgument("Gather op requires CustomCall on CPU")); - - xla::ComputationBuilder& b = *context->builder(); - - // Call gather_xla_float_kernel (from gather_op_kernel_float.cc). - // XLA passes to the function, so it is not included here. - std::vector args; - args.push_back(tc.GetOrCreateRuntimeContextParameter()); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR0(indices_shape.num_elements()))); - args.push_back( - b.ConstantLiteral(*xla::Literal::CreateR0(outer_size))); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR0(params_shape.dim_size(axis)))); - args.push_back( - b.ConstantLiteral(*xla::Literal::CreateR0(inner_size))); - args.push_back(context->Input(0)); - args.push_back(context->Input(1)); - - xla::Shape xla_out_shape; - OP_REQUIRES_OK( - context, TensorShapeToXLAShape(DT_FLOAT, result_shape, &xla_out_shape)); - - // Call the custom code with args: - xla::ComputationDataHandle output; - if (index_type == DT_INT32) { - output = b.CustomCall("gather_float_int32_xla_impl", args, xla_out_shape); - } else { - output = b.CustomCall("gather_float_int64_xla_impl", args, xla_out_shape); - } - - context->SetOutput(0, output); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(GatherOpCustomCall); -}; - -REGISTER_XLA_OP(Name("Gather") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_CPU_XLA_JIT), - GatherOpCustomCall); -REGISTER_XLA_OP(Name("GatherV2") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_CPU_XLA_JIT), - GatherOpCustomCall); - -} // namespace - GatherOpDynamicSlice::GatherOpDynamicSlice(OpKernelConstruction* context) : XlaOpKernel(context) {} @@ -303,20 +187,17 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { ", ", params_dims, "), but got ", axis)); } - xla::ComputationDataHandle gather = - XlaComputeGatherDynamicSlice(context, input, input_shape, indices, - indices_shape, axis, DT_FLOAT, builder); + DataType index_type = input_type(1); + OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("indices must be int32 or int64")); + + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( + context, input, input_shape, indices, indices_shape, axis, input_type(0), + index_type, builder); context->SetOutput(0, gather); } -REGISTER_XLA_OP(Name("Gather") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_GPU_XLA_JIT), - GatherOpDynamicSlice); - -REGISTER_XLA_OP(Name("GatherV2") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_GPU_XLA_JIT), - GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 5623c4d1c2be18b73696052c7fad822355a30f77..2c80395c56d73adad7dc1679ba6423fbe103605a 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -28,11 +28,13 @@ namespace tensorflow { // Adds to builder an XLA computation that performs a gather on input (of // shape input_shape) keyed on indices (of shape indices_shape). +// +// index_type must be must be DT_INT32 or DT_INT64. xla::ComputationDataHandle XlaComputeGatherDynamicSlice( XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, const TensorShape& input_shape, const xla::ComputationDataHandle& indices, const TensorShape& indices_shape, int64 axis, DataType dtype, - xla::ComputationBuilder* builder); + DataType index_type, xla::ComputationBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc deleted file mode 100644 index 33b1b087d00d8263cd80f7d5d879401e4ed6c0fb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/gather_functor.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); - - int64 indices_size = *static_cast(data[1]); - int64 params_x = *static_cast(data[2]); - int64 params_y = *static_cast(data[3]); - int64 params_z = *static_cast(data[4]); - - float* in = static_cast(data[5]); - - int32* indices = static_cast(data[6]); - Eigen::DSizes in_eig_sizes; - in_eig_sizes[0] = params_x; - in_eig_sizes[1] = params_y; - in_eig_sizes[2] = params_z; - tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); - - Eigen::DSizes indices_eig_sizes; - indices_eig_sizes[0] = indices_size; - tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = params_x; - out_eig_sizes[1] = indices_size; - out_eig_sizes[2] = params_z; - tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); - - tensorflow::functor::GatherFunctorCPU f; - const int64 bad_i = f(in_eig, indices_eig, out_eig); - if (bad_i != -1) { - tensorflow::XlaLocalRuntimeContext* runtime_context = - static_cast(data[0]); - runtime_context->error = true; - runtime_context->error_msg = "Invalid index for gather"; - for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; - } -} - -} // namespace tensorflow - -// Implements gather on CPU. This is called by an XLA custom call, set up by -// gather_op.cc. -extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { - tensorflow::gather_float_int32_xla_impl(out, data); -} diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc deleted file mode 100644 index 5e2d872ce0b28ab479c73ed1fea5f32804c21e22..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/gather_functor.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); - - int64 indices_size = *static_cast(data[1]); - int64 params_x = *static_cast(data[2]); - int64 params_y = *static_cast(data[3]); - int64 params_z = *static_cast(data[4]); - - float* in = static_cast(data[5]); - - int64* indices = static_cast(data[6]); - Eigen::DSizes in_eig_sizes; - in_eig_sizes[0] = params_x; - in_eig_sizes[1] = params_y; - in_eig_sizes[2] = params_z; - tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); - - Eigen::DSizes indices_eig_sizes; - indices_eig_sizes[0] = indices_size; - tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = params_x; - out_eig_sizes[1] = indices_size; - out_eig_sizes[2] = params_z; - tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); - - tensorflow::functor::GatherFunctorCPU f; - const int64 bad_i = f(in_eig, indices_eig, out_eig); - if (bad_i != -1) { - tensorflow::XlaLocalRuntimeContext* runtime_context = - static_cast(data[0]); - runtime_context->error = true; - runtime_context->error_msg = "Invalid index for gather"; - for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; - } -} - -} // namespace tensorflow - -// Implements gather on CPU. This is called by an XLA custom call, set up by -// gather_op.cc. -extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { - tensorflow::gather_float_int64_xla_impl(out, data); -} diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index db7d556630a04d93a7eee308117dd429b8af26d1..b8769b3ea2be0a791d9c3e5e7acd8b6184442af2 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -82,16 +83,24 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); // Compute a mask that has 1s for elements equal to the maximum. - xla::ComputationDataHandle mask = b->ConvertElementType( + xla::ComputationDataHandle partial_mask = b->ConvertElementType( b->Eq(input, input_max, broadcast_dims), xla_index_type); - // Multiply by the vector [0, 1, 2, ...] to convert each 1 into its index. - // TODO(phawkins): add a bitwise And operator to HLO, use a bitwise and - // instead of a multiplication here. + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding 0xFF...F + int32 bits_in_type = + xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_index_type) * 8 - 1; + xla::ComputationDataHandle shift_amount = + XlaHelpers::IntegerLiteral(b, index_type, bits_in_type); + xla::ComputationDataHandle full_mask = b->ShiftRightArithmetic( + b->ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its index. xla::ComputationDataHandle iota; OP_REQUIRES_OK(ctx, XlaHelpers::Iota(b, index_type, axis_size, &iota)); xla::ComputationDataHandle product = - b->Mul(mask, iota, /*broadcast_dimensions=*/{axis}); + b->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); // If there are multiple maximum elements, choose the one with the highest // index. diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index afbd64ca5038378d48744d6d773e0dfb1376e1f9..47cf8c6675bc120653c2a5ab6d4b07376dc382ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -47,3 +48,5 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } + +REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index 841ff2f4df79fdd790ee3aace9e38aaeb01a3080..9b83392d8fbe461970603fbadee76e8d71b1ebd0 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -49,3 +50,5 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } + +REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 5c799a0e4f86db04dc966411e0c917387186ce59..fcef497e5845d9080bc83b54e92dcf2fdecf5f12 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,6 +23,9 @@ limitations under the License. namespace tensorflow { namespace { +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; + class MatMulOp : public XlaOpKernel { public: explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false) @@ -73,7 +76,7 @@ class MatMulOp : public XlaOpKernel { bool transpose_b_; }; -REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kFloatTypes), MatMulOp); +REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kMatmulTypes), MatMulOp); class SparseMatMulOp : public MatMulOp { public: diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 66b99665cbefd9ffd2acabe6eb296f485ca6a59d..2421825ead17a3acee9f145f00904d382fb656f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -140,7 +140,7 @@ class TruncatedNormalOp : public XlaOpKernel { xla::ComputationBuilder* b) { xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); - return b->LogicalOr(too_large, too_small); + return b->Or(too_large, too_small); }; // The algorithm we're using is roughly: diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index dae2eb9d2a92ef8d4eabb8d6f9a79758c42d446d..647b6274083cf8886af6c451b746416445a4a2b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -129,7 +129,7 @@ class AllOp : public XlaReductionOp { void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { - builder->LogicalAnd(scalar_lhs, scalar_rhs); + builder->And(scalar_lhs, scalar_rhs); } }; @@ -147,7 +147,7 @@ class AnyOp : public XlaReductionOp { void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { - builder->LogicalOr(scalar_lhs, scalar_rhs); + builder->Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index a137d28118e6b4c66c70253817be9b3f0b75088a..12a35529992e6160566046dd28f9321c88afec91 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -77,9 +77,9 @@ class Relu6GradOp : public XlaOpKernel { b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); const auto six = b->Broadcast( XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); - auto out = b->Select( - b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), - ctx->Input(0), zero); + auto out = + b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); ctx->SetOutput(0, out); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index ed818c56ed0e6fa41374234d6f6712a2bbda94e2..5172781c0d05b6682fe92086654e3b86961949ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index a0d8ab4d73f7491fe96299c6cdc918f00a3d7a97..750a4c2dec8154f97f307978b3d8884271292279 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -202,7 +202,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. xla::ComputationDataHandle nan_or_zero = builder->Select( - builder->LogicalAnd( + builder->And( builder->Le(XlaHelpers::Zero(builder, indices_type), indices), builder->Lt(indices, XlaHelpers::IntegerLiteral( builder, indices_type, depth))), diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index e2d3d40813b63360d37392b59b3c68cc478b077b..351fda251798e43b607fb445f2c98abd57b3d86b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -307,11 +307,12 @@ class TensorArrayGatherOp : public XlaOpKernel { OP_REQUIRES(ctx, indices_shape.dims() == 1, errors::InvalidArgument("indices must be rank 1")); auto indices = ctx->Input(1); + DataType index_type = ctx->input_type(1); xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, b); + ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); ctx->SetOutput(0, gather); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 82ae0df5cc501cf1b51c2b25b9330d582fbdc44c..5534d1bfa1338c7fe3647cd6aa281c4907dfdf8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -37,8 +37,9 @@ class ResourceApplyGradientDescent : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); } }; -REGISTER_XLA_OP(Name("ResourceApplyGradientDescent"), - ResourceApplyGradientDescent); +REGISTER_XLA_OP( + Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes), + ResourceApplyGradientDescent); class ResourceApplyMomentum : public XlaOpKernel { public: @@ -109,7 +110,8 @@ class ResourceApplyMomentum : public XlaOpKernel { private: bool use_nesterov_; }; -REGISTER_XLA_OP(Name("ResourceApplyMomentum"), ResourceApplyMomentum); +REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes), + ResourceApplyMomentum); class ResourceApplyAdagrad : public XlaOpKernel { public: @@ -163,7 +165,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } }; -REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad); +REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes), + ResourceApplyAdagrad); class ResourceApplyAdam : public XlaOpKernel { public: @@ -263,7 +266,8 @@ class ResourceApplyAdam : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam); +REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes), + ResourceApplyAdam); class ResourceApplyRMSProp : public XlaOpKernel { public: @@ -362,7 +366,8 @@ class ResourceApplyRMSProp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom)); } }; -REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp); +REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), + ResourceApplyRMSProp); void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { @@ -500,7 +505,8 @@ class ResourceApplyFtrl : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl); +REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes), + ResourceApplyFtrl); class ResourceApplyFtrlV2 : public XlaOpKernel { public: @@ -515,7 +521,8 @@ class ResourceApplyFtrlV2 : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyFtrlV2"), ResourceApplyFtrlV2); +REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes), + ResourceApplyFtrlV2); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 3e4a0f5950059b577eb470e5623cabd1d4ca6be9..a266e9013c41b88788dbc99849f01c09f3d61348 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -41,6 +41,12 @@ namespace { }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op); +XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x)); + +XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x))); + +XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); + // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); @@ -87,7 +93,8 @@ XLAJIT_MAKE_UNARY(Log, b->Log(x)); // TODO(b/34703906): use a more accurate implementation of log1p. XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); -XLAJIT_MAKE_UNARY(LogicalNot, b->LogicalNot(x)); +XLAJIT_MAKE_UNARY(Invert, b->Not(x)); +XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two @@ -104,9 +111,9 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, auto nearest_even_int = b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); auto is_odd = b->Eq(nearest_even_int, one); - return b->Select(b->LogicalOr(b->Gt(fraction, half), - b->LogicalAnd(b->Eq(fraction, half), is_odd)), - b->Add(round_val, one), round_val); + return b->Select( + b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), + b->Add(round_val, one), round_val); } XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); @@ -161,6 +168,9 @@ XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); +XLAJIT_MAKE_UNARY(Real, b->Real(x)); +XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); + #undef XLAJIT_MAKE_UNARY } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 4ae983854780e49d8c7af29f218d3fbb7db225e2..b19ea22f50d2dd44e8d1d81f5930263f364030e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -111,9 +111,10 @@ class ResourceGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); + DataType index_type = ctx->input_type(1); xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( ctx, resource_handle, resource_shape, indices, indices_shape, 0, - resource_dtype, builder); + resource_dtype, index_type, builder); ctx->SetOutput(0, gather); } }; diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc index c1005405f9a9b09e4a6480332861d0cce2c52291..4a669f8e6eaf644f119f3c0a66f29d9f2c9a9d16 100644 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc @@ -34,14 +34,41 @@ output = input; While (Cond(output)) { output = Body(output) } input: A list of input tensors whose types are T. output: A list of output tensors whose types are T. cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc index b6947bfe570c75dd0c7c6301b972e2012bae26bd..4b41c16a8b3fdc0c3412c76d29d3ec2b7bdfd0aa 100644 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc @@ -37,7 +37,14 @@ REGISTER_OP("_XLARecv") .Attr("tensor_name: string") .Attr("shape: shape") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) .Doc(R"doc( Receives the named tensor from another XLA computation. diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index b7213a6cc1e4066f98523ec57681f4c0651f71b5..a14c93a2b9494b89f579bc20ee0510c136f8f01b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -255,11 +255,10 @@ Status CreateXlaArgs(const Graph& graph, Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, xla::Computation* computation, bool* requires_runtime_context) { - // Create a device and context to convert the graph into an XLA computation. XlaOpRegistry::RegisterCompilationKernels(); - // Populate the context with args from the graph. for (Node* node : graph->nodes()) { - node->set_assigned_device_name(DEVICE_CPU_XLA_JIT); + node->set_assigned_device_name( + strings::StrCat("/device:", DEVICE_CPU_XLA_JIT)); } std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index b54848f342406c9211c06664fd1f6c0783e0891f..1efbe0ffb17dad5332aa700b2e255d4a99fbef72 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -43,6 +43,12 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_UINT16: *type = xla::U16; return Status::OK(); + case tensorflow::DT_UINT32: + *type = xla::U32; + return Status::OK(); + case tensorflow::DT_UINT64: + *type = xla::U64; + return Status::OK(); case tensorflow::DT_HALF: *type = xla::F16; return Status::OK(); @@ -52,6 +58,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_DOUBLE: *type = xla::F64; return Status::OK(); + case tensorflow::DT_COMPLEX64: + *type = xla::C64; + return Status::OK(); case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8521d4167a11ebb7d8af87ca2e18e0140bc76eb9..e49663b8b047fb5f2c9ba17fa0aa032a673e7ed7 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,12 +18,15 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" +#include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -92,7 +95,6 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) } local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), - FunctionDefLibrary{})); local_pflr_.reset(new ProcessFunctionLibraryRuntime( &device_mgr_, Env::Default(), options.graph_def_version, @@ -127,6 +129,37 @@ static Status GetFunctionBody(const NameAttrList& function, return Status::OK(); } +Status XlaCompiler::FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody) { + // The function may be in either the local_flib_runtime_ or flib_runtime_. + // Look up the function in local first and if it is not found then look up the + // function in flib_runtime_. + auto status = GetFunctionBody(function, local_flib_runtime_, fbody); + if (!status.ok()) { + if (!errors::IsNotFound(status)) { + return status; + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + GetFunctionBody(function, flib_runtime_, fbody), + "Local lookup failed with: ", status.error_message()); + } + return Status::OK(); +} + +std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { + std::unique_ptr graph(new Graph(options_.flib_def)); + CopyGraph(*fbody->graph, graph.get()); + OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); + GraphOptimizer optimizer(opts); + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), + /*device=*/nullptr, &graph, /*shape_map=*/nullptr); + + return graph; +} + Status XlaCompiler::CompileFunction( const XlaCompiler::CompileOptions& options, const NameAttrList& function, const std::vector& args, @@ -142,11 +175,11 @@ Status XlaCompiler::CompileFunction( } const FunctionBody* fbody; - if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) { - TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody)); - } + TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody)); - TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckSignature(fbody->arg_types, args), + "Signature check failure while compiling: ", function.name()); std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); @@ -181,7 +214,7 @@ namespace { Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, XlaCompilationDevice* device, FunctionLibraryRuntime* flib, int64 step_id) { - // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the + // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the // resource manager takes ownership via Create, and unrefs via Cleanup. We // explicitly add a reference to ensure the refcount at entry is maintained at // all exit points; Create and Cleanup are always called in this function. @@ -198,55 +231,12 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - // Create a LocalExecutor that will own and run the graph. - // TODO(b/66947550): migrate away from using an Executor in order to guarantee - // determinism and thread-safety. - LocalExecutorParams exec_params; - exec_params.device = device; - exec_params.function_library = flib; - exec_params.create_kernel = [flib](const NodeDef& ndef, OpKernel** kernel) { - return flib->CreateKernel(ndef, kernel); - }; - exec_params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; - Executor* exec_ptr = nullptr; - TF_RETURN_IF_ERROR(NewLocalExecutor(exec_params, graph.release(), &exec_ptr)); - std::unique_ptr exec(exec_ptr); - // At this point ownership of the graph has been transferred to exec. - - // Run the graph symbolically, turning the graph into an XLA computation. - Executor::Args exec_args; - exec_args.step_id = step_id; - exec_args.step_container = step_container.get(); - - // Pushes closures to run onto `worklist`. We don't run the closures directly - // from 'runner' since that might lead to a stack overflow for large graphs. - std::deque worklist; - exec_args.runner = [&](Executor::Args::Closure c) { - worklist.push_back(std::move(c)); - }; - - // The following code assumes there is only one thread involved and no - // concurrency, because we did not provide Executor a threaded runner. Async - // ops on the XlaCompilation device must not use threads or concurrency - // internally. - bool done = false; - exec->RunAsync(exec_args, [&](const Status& s) { - status = s; - done = true; - }); - // Repeatedly run closures from the worklist until `done` is signalled. - while (!done) { - TF_RET_CHECK(!worklist.empty()); - Executor::Args::Closure& c = worklist.front(); - c(); - worklist.pop_front(); - } - TF_RETURN_WITH_CONTEXT_IF_ERROR( - status, "Conversion from TensorFlow graph to XLA computation failed."); - + GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, + step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); - return status; + return Status::OK(); } // Builds XLA computations for each of the arguments to the computation. @@ -509,7 +499,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->requires_runtime_context = context->has_context_parameter(); // Tuple arguments and runtime context parameters are incompatible. - CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); + TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; @@ -546,7 +536,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, i < context->retvals().size(); ++i) { const XlaExpression& retval = context->retvals()[i]; if (!retval.has_constant_value()) { - CHECK_LT(computation_output, num_computation_outputs); + TF_RET_CHECK(computation_output < num_computation_outputs) + << "Computation has more outputs than expected"; OutputDescription& output = result->outputs[i]; output.is_constant = false; TF_RETURN_IF_ERROR(XLAShapeToTensorShape( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 35159dbad4117895908584ad48878e2a989b9f40..a8882a638caf2d742bfa2b4f68140e1dc4520db1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { - // The XlaCompiler class is responsible for compilation of a self-contained // subgraph of a TensorFlow computation using the XLA linear algebra runtime. // It does a symbolic execution of the graph starting from specific input @@ -136,6 +135,27 @@ class XlaCompiler { bool operator==(const Argument& other) const; }; + // Options pertaining to an individual call to CompileGraph() or + // CompileFunction(). + struct CompileOptions { + // If `use_tuple_arg` is true, a single tuple parameter will be used for all + // arguments; if false, each argument gets its own parameter. + bool use_tuple_arg = false; + + // If 'return_updated_values_for_all_resources' is true, then updated + // values of all resource arguments will be included in the + // 'resource_updates' of the computation, even if the resource was not + // modified by the computation. Used when compiling loop bodies to ensure + // the input and output signatures match. + bool return_updated_values_for_all_resources = false; + + // If 'resolve_compile_time_constants' is true, then outputs of a + // computation that are known to be compile-time constants will be returned + // as Tensors at compile-time, rather than as run-time outputs of the + // computation. + bool resolve_compile_time_constants = true; + }; + struct OutputDescription { // Type and shape of the output. DataType type; @@ -230,39 +250,9 @@ class XlaCompiler { }; explicit XlaCompiler(Options options); - ~XlaCompiler(); - - // Options pertaining to an individual call to CompileGraph() or - // CompileFunction(). - struct CompileOptions { - // If `use_tuple_arg` is true, a single tuple parameter will be used for all - // arguments; if false, each argument gets its own parameter. - bool use_tuple_arg = false; - - // If 'return_updated_values_for_all_resources' is true, then updated - // values of all resource resources arguments will be included in the - // 'resource_updates' of the computation, even if the resource was not - // modified by the computation. Used when compiling loop bodies to ensure - // the input and output signatures match. - bool return_updated_values_for_all_resources = false; - // If 'resolve_compile_time_constants' is true, then outputs of a - // computation that are known to be compile-time constants will be returned - // as Tensors at compile-time, rather than as run-time outputs of the - // computation. - bool resolve_compile_time_constants = true; - }; + ~XlaCompiler(); - // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. - // `args` describes the arguments to the function, each of which must either - // be a runtime-parameter to the XLA computation, a compile-time constant, or - // a resource variable. Writes the compiled output to `result`. - // - // The generated XLA computation returns a tuple containing only the - // non-constant outputs as a function of the input arguments. Constant - // arguments are returned as host memory tensors in the output list and are - // not included in the XLA computation's outputs. The XLA computation is - // null if there are no data-dependent outputs and no side effects. Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, const std::vector& args, @@ -276,10 +266,17 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); + Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func, + const std::vector& types, + const std::vector& shapes, + const std::vector& expressions, + std::vector* args); + // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. - // Channel handles can be used to communicate between different computations. - // Computations that communicate should be compiled with the same XlaCompiler. + // Channel handles can be used to communicate between different + // computations. Computations that communicate should be compiled with the + // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); const Options& options() const { return options_; } @@ -287,6 +284,18 @@ class XlaCompiler { FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } private: + // Sets the function body `fbody` to the one registered as `function`. + Status FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody); + + // Returns the optimized graph object in this function body. + std::unique_ptr GetGraph(const FunctionBody* fbody); + + // Graph compiler needs to know how to get an optimized graph from a function + // body. + friend class GraphCompiler; + friend class XlaCompilerTest; + Options options_; // Status set to non-OK in the constructor if initialization fails. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 531725a62335fc30086de2fe381591eb7d0976d0..93aae8485d157cd4afbf804d695d5c0ab8d7946c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" @@ -36,6 +38,37 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class XlaCompilerTest : public ::testing::Test { + protected: + XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} + + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + + XlaOpRegistry::RegisterCompilationKernels(); + + FunctionDefLibrary flib; + flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + } + + XlaCompiler::Options DefaultOptions() { + XlaCompiler::Options options; + options.device_type = &cpu_device_type_; + options.client = client_; + options.flib_def = flib_def_.get(); + return options; + } + + FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) { + return compiler->local_flib_def_.get(); + } + + DeviceType cpu_device_type_; + xla::Client* client_; + std::unique_ptr flib_def_; +}; + namespace { // Helper class to test the ability to pass resources through to XLA @@ -63,6 +96,7 @@ class DummyReadResourceOp : public XlaOpKernel { dummy->Unref(); ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(1, ctx->Input(0)); } }; @@ -80,22 +114,25 @@ class DummyReadResourceCC { if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); if (!scope.ok()) return; - this->output_ = Output(ret, 0); + this->output1_ = Output(ret, 0); + this->output2_ = Output(ret, 1); } - Node* node() const { return output_.node(); } - Output output_; + Output output1_; + Output output2_; }; REGISTER_OP("DummyReadResource") .Input("input: int32") - .Output("output: int32") + .Output("output1: int32") + .Output("output2: int32") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( A dummy Op. input: dummy input. -output: dummy output. +output1: dummy output. +output2: dummy output. )doc"); REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp); @@ -125,31 +162,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), DummyDuplicateOp); -class XlaCompilerTest : public ::testing::Test { - protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - - void SetUp() override { - client_ = xla::ClientLibrary::LocalClientOrDie(); - - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - } - - XlaCompiler::Options DefaultOptions() { - XlaCompiler::Options options; - options.device_type = &cpu_device_type_; - options.client = client_; - options.flib_def = flib_def_.get(); - return options; - } - - DeviceType cpu_device_type_; - xla::Client* client_; - std::unique_ptr flib_def_; -}; // Tests compilation and execution of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { @@ -316,7 +328,8 @@ TEST_F(XlaCompilerTest, ResourceManager) { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto b = DummyReadResourceCC(scope.WithOpName("B"), a); - auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0); + auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(scope.ToGraph(graph.get())); @@ -349,6 +362,58 @@ TEST_F(XlaCompilerTest, ResourceManager) { resource->Unref(); } +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, DeterministicCompilation) { + // Builds a graph that contains a node with two output edges. The compiler + // should always traverse them in the same order. + const int64 test_count = 2; + + std::vector results(test_count); + + for (int64 i = 0; i < test_count; ++i) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::Neg(scope.WithOpName("B"), a); + auto c = ops::Neg(scope.WithOpName("C"), a); + auto d = ops::Add(scope.WithOpName("D"), b, c); + auto e = ops::_Retval(scope.WithOpName("E"), d, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the argument. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + + // Compiles the graph. + auto options = DefaultOptions(); + XlaCompiler compiler(options); + + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", + std::move(graph), args, &results[i])); + } + + for (int64 i = 1; i < test_count; ++i) { + auto m1 = + results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests(); + auto m2 = + results[i].computation->Snapshot().ValueOrDie()->entry().requests(); + // Check if every entry is the same. + for (auto& entry1 : m1) { + int64 key = entry1.first; + auto value1 = entry1.second; + auto entry2 = m2.find(key); + auto value2 = entry2->second; + EXPECT_TRUE(entry2 != m2.end()); + string str1, str2; + value1.AppendToString(&str1); + value2.AppendToString(&str2); + EXPECT_EQ(str1, str2); + } + } +} + // Tests a computation that receives a TensorArray resource as input and // updates it. TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { @@ -489,5 +554,104 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { EXPECT_EQ(1, result.resource_updates.size()); } +// Tests CompileFunction with undefined function fails. +TEST_F(XlaCompilerTest, UndefinedFunctionFails) { + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("Function_NotDefined_"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); +} + +FunctionDef FillFn() { + return FunctionDefHelper::Define( + // Name + "FillFn", + // Args + {"x: T", "dims: int32"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}}); +} + +TEST_F(XlaCompilerTest, FunctionCallWithConstants) { + // Certain operations in a function, "Fill" for example, requires the + // operator's argument to be a compile-time constant instead of a parameter. + // This testcase tests if XlaCompiler can handle such operators inside + // function calls. + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + *flib.add_function() = FillFn(); + + TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn())); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the argument. + std::vector args; + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result)); +} + +// Tests CompileFunction with a local function lookup failing, fails with +// informative error about both lookups. +TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { + XlaCompiler compiler(DefaultOptions()); + + auto local_flib_def = LocalFlibDef(&compiler); + TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo())); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("XTimesTwo"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + + ASSERT_FALSE(status.ok()); + // Flib lookup failure. + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); + // Local flib lookup failure. + EXPECT_TRUE( + StringPiece(status.error_message()).contains("Attr T is not found")) + << status.error_message(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f59b83cfdd778209935970981a1463d350a64be6..de5ad5f176536e1453da518b96ee755c7f1e8fdc 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -97,6 +97,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::F64: literal = *xla::Literal::CreateR0(value); break; + case xla::C64: + literal = *xla::Literal::CreateR0(value); + break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; case xla::S16: @@ -132,6 +135,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, case xla::F64: return b->ConstantR0(value); break; + case xla::C64: + return b->ConstantR0(value); + break; default: LOG(FATAL) << "unhandled element type " << type; } diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 5bee68eefc8d9452b63113c080fc86d39550e899..6d49298a6f3e8a726695fafc42f3c5341fe98b5f 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -129,5 +129,19 @@ TEST(XlaJitCompiledCpuFunction, Sum) { EXPECT_TRUE(ShapeUtil::Compatible(result0, s32)); } +// Test when a graph compilation terminates early, resources are properly +// reclaimed. +TEST(XlaJitCompiledCpuFunction, SumWithJunkAttr) { + GraphDef graph_def = SumGraph(); + + (*graph_def.mutable_node(2)->mutable_attr())["junk"] = + TypeAttrValue(DT_INT32); + + tf2xla::Config config = SumConfig(); + EXPECT_FALSE(XlaJitCompiledCpuFunction::Compile(graph_def, config, + xla::ExecutableBuildOptions()) + .ok()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 1a8d03757a2b9bdee339ecb951a67528719314d4..6aee8c91cc01b4382ef867fa8e438eede008ac73 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -45,17 +45,19 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; -constexpr std::array kIntTypes = {{DT_INT32, DT_INT64}}; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kNumericTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}}; +constexpr std::array kNumericTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64}}; -constexpr std::array kCpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 6c4c970ce838400794e9fd4f3bddb829d8a14e5b..660f419e464936b01a3644e69c2f056f998140f5 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -7,7 +7,6 @@ package_group( packages = [ "//tensorflow/compiler/...", "//tensorflow/contrib/tpu/...", - "//tensorflow/contrib/xla_tf_graph/...", ], ) @@ -171,6 +170,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status", + ":status_macros", ":types", ":xla_data_proto", "//tensorflow/core:lib", @@ -335,12 +335,32 @@ cc_library( ], ) +cc_library( + name = "array", + hdrs = ["array.h"], + deps = [ + ":types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "array_test", + srcs = ["array_test.cc"], + deps = [ + ":array", + ":test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "array2d", srcs = ["array2d.cc"], hdrs = ["array2d.h"], visibility = ["//visibility:public"], deps = [ + ":array", ":types", ":util", "//tensorflow/core:lib", @@ -362,6 +382,7 @@ cc_library( hdrs = ["array3d.h"], visibility = [":friends"], deps = [ + ":array", ":types", "//tensorflow/core:lib", ], @@ -383,6 +404,7 @@ cc_library( hdrs = ["array4d.h"], visibility = [":friends"], deps = [ + ":array", ":array2d", ":types", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h new file mode 100644 index 0000000000000000000000000000000000000000..2aedafb91f1f737a369d5eeaf4b03c49803300ab --- /dev/null +++ b/tensorflow/compiler/xla/array.h @@ -0,0 +1,324 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_ +#define TENSORFLOW_COMPILER_XLA_ARRAY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// General N dimensional array class with arbitrary value type. +template +class Array { + public: + // Creates a new array with the specified dimensions. + explicit Array(const std::vector& sizes) : Array(sizes, T()) {} + + // Creates a new array with the specified dimensions and specified value for + // every cell. + Array(const std::vector& sizes, T value) + : sizes_(sizes), values_(new T[num_elements()]) { + Fill(value); + } + + // Creates a 2D array from the given nested initializer list. The outer + // initializer list is the first dimension, the inner is the second dimension. + // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. + Array(std::initializer_list> values) + : Array(ToInt64Vector({values.size(), values.begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + values_[idx] = it2; + ++idx; + } + } + CHECK(idx == num_elements()); + } + + // Creates a 3D array from the given nested initializer list. The outer + // initializer list is the first dimension, and so on. + Array(std::initializer_list>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + values_[idx] = it3; + ++idx; + } + } + } + CHECK(idx == num_elements()); + } + + // Creates a 4D array from the given nested initializer list. The outer + // initializer list is the first dimension, and so on. + Array(std::initializer_list< + std::initializer_list>>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size(), + values.begin()->begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + for (const auto& it4 : it3) { + values_[idx] = it4; + ++idx; + } + } + } + } + CHECK(idx == num_elements()); + } + + Array(const Array& other) + : sizes_(other.sizes_), values_(new T[num_elements()]) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + } + + Array& operator=(const Array& other) { + sizes_ = other.sizes_; + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + + // Fills the array with the specified value. + void Fill(const T& value) { + std::fill(&values_[0], &values_[0] + num_elements(), value); + } + + // Fills the array with sequentially increasing values. + void FillIota(const T& value) { + std::iota(&values_[0], &values_[0] + num_elements(), value); + } + + // Fills the array with the sequence i*multiplier for i=0,1,... + void FillWithMultiples(const T& multiplier) { + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = i * multiplier; + } + } + + // Fills the array with random normal variables with the specified mean. + void FillRandom(const T& value, const double mean = 0.0, + const int seed = 12345) { + std::mt19937 g(seed); + std::normal_distribution distribution(mean, + static_cast(value)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); + } + } + + // Sets all the values in the array to values specified in the container. + template > + void SetValues(const Container& container) { + CHECK_EQ(std::distance(std::begin(container), std::end(container)), + num_elements()); + std::copy(std::begin(container), std::end(container), &values_[0]); + } + + // Invokes a callback with the (indices, value_ptr) for each cell in the + // array. + void Each(std::function, T*)> f) { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + f(index, &values_[i]); + } + } + + // Invokes a callback with the (indices, value) for each cell in the array. + void Each( + std::function, T)> f) const { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + f(index, values_[i]); + } + } + + // Returns the value at the cell specified by the indexes. The number of + // arguments have to match with the number of dimensions for the array. + template + const T& operator()(Dims... dims) const { + // We are using a std::array to avoid having to allocate memory in this + // function for performance reasons. + std::array indexes{{static_cast(dims)...}}; + return values_[calculate_index(indexes)]; + } + + // Returns the value at the cell specified by the indexes. The number of + // arguments have to match with the number of dimensions for the array. + template + T& operator()(Dims... dims) { + // We are using a std::array to avoid having to allocate memory in this + // function for performance reasons. + std::array indexes{{static_cast(dims)...}}; + return values_[calculate_index(indexes)]; + } + + // Low-level accessor for stuff like memcmp, handle with care. Returns pointer + // to the underlying storage of the array (similarly to std::vector::data()). + T* data() const { + // TODO(tberghammer): Get rid of the const_cast. Currently it is needed + // because the Eigen backend needs a non-const pointers even for reading + // from the array. + return const_cast(this)->values_.get(); + } + + // Returns the size of the dimension at the given index. + int64 dim(int64 n) const { + CHECK(n < sizes_.size()); + return sizes_[n]; + } + + // Returns a vector containing the dimensions of the array. + const std::vector& dimensions() const { return sizes_; } + + int64 num_dimensions() const { return sizes_.size(); } + + // Returns the total number of elements in the array. + int64 num_elements() const { + return std::accumulate(sizes_.begin(), sizes_.end(), 1, + std::multiplies()); + } + + bool operator==(const Array& other) const { + if (sizes_.size() != other.sizes_.size()) { + return false; + } + for (int64 i = 0; i < sizes_.size(); ++i) { + if (sizes_[i] != other.sizes_[i]) { + return false; + } + } + for (int64 i = 0; i < num_elements(); ++i) { + if (values_[i] != other.values_[i]) { + return false; + } + } + return true; + } + + bool operator!=(const Array& other) const { return !(*this == other); } + + // Returns a string representation of the array suitable for debugging. + string ToString() const { + std::vector pieces; + std::vector index(sizes_.size()); + do { + // Emit leading spaces and opening square brackets + if (index.back() == 0) { + for (int64 i = sizes_.size() - 1; i >= 0; --i) { + if (i == 0 || index[i - 1] != 0) { + for (int64 j = 0; j < sizes_.size(); ++j) { + pieces.push_back(j < i ? " " : "["); + } + break; + } + } + } + + pieces.push_back( + tensorflow::strings::AlphaNum(values_[calculate_index(index)]) + .data()); + + // Emit comma if it isn't the last element + if (index.back() != sizes_.back() - 1) { + pieces.push_back(", "); + } + + // Emit closing square brackets + for (int64 i = sizes_.size() - 1; i >= 0; --i) { + if (index[i] != sizes_[i] - 1) { + break; + } + pieces.push_back("]"); + if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) { + pieces.push_back(",\n"); + } + } + } while (next_index(&index)); + return tensorflow::str_util::Join(pieces, ""); + } + + private: + // Converts an initializer_list of type U to a vector of type int64. Used by + // the initializer list based constructors to convert the size type into int64 + // to be passed to the size based constructor. + template + static std::vector ToInt64Vector( + const std::initializer_list& data) { + return std::vector(data.begin(), data.end()); + } + + // Returns the linear index from the list of per-dimension indexes. Function + // is templated so can be used with an std::array from operator() to avoid + // memory allocation. + template + int64 calculate_index(const U& indexes) const { + CHECK_EQ(sizes_.size(), indexes.size()); + int64 index = 0; + for (int64 i = 0; i < sizes_.size(); ++i) { + index *= sizes_[i]; + index += indexes[i]; + } + return index; + } + + // Advances the specified set of indexes and returns true if we haven't + // wrapped around (i.e. result isn't {0, 0, ...}). + bool next_index(std::vector* index) const { + CHECK_EQ(index->size(), sizes_.size()); + for (int64 i = sizes_.size() - 1; i >= 0; --i) { + (*index)[i]++; + if ((*index)[i] < sizes_[i]) { + return true; + } + (*index)[i] = 0; + } + return false; + } + + std::vector sizes_; + std::unique_ptr values_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ARRAY_H_ diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 2737764cbda87298599d7005c237a2093cbaba4a..bb85fbee9b97fd6b9b0bf7223a9b820989dcbfa7 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -34,93 +35,30 @@ limitations under the License. namespace xla { -// Simple 2D array structure. -// -// The data layout in major-to-minor order is: n1, n2. template -class Array2D { +class Array2D : public Array { public: - // Creates an empty array. - Array2D() : n1_(0), n2_(0) {} + Array2D() : Array(std::vector{0, 0}) {} - // Creates an array of dimensions n1 x n2, uninitialized values. Array2D(const int64 n1, const int64 n2) - : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { - Fill(T()); - } + : Array(std::vector{n1, n2}) {} - // Creates an array of dimensions n1 x n2, initialized to value. Array2D(const int64 n1, const int64 n2, const T value) - : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { - Fill(value); - } + : Array({n1, n2}, value) {} // Creates an array from the given nested initializer list. The outer // initializer list is the first dimension; the inner is the second dimension. // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. Array2D(std::initializer_list> values) - : Array2D(values.size(), values.begin()->size()) { - int64 n1 = 0; - for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) { - int64 n2 = 0; - for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) { - (*this)(n1, n2) = *n2_it; - } - } - } + : Array(values) {} - Array2D(const Array2D& other) : Array2D(other.n1(), other.n2()) { - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - } - - Array2D& operator=(const Array2D& other) { - n1_ = other.n1(); - n2_ = other.n2(); - values_.reset(new T[num_elements()]); - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - return *this; - } + Array2D(const Array2D& other) : Array(other) {} - T& operator()(const int64 i1, const int64 i2) { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - return values_[i1 * n2_ + i2]; - } + int64 n1() const { return this->dim(0); } + int64 n2() const { return this->dim(1); } - const T& operator()(const int64 i1, const int64 i2) const { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - return values_[i1 * n2_ + i2]; - } - - // Access to the array's dimensions. height() and width() provide the - // canonical interpretation of the array n1 x n2 having n1 rows of n2 columns - // each (height is number of rows; width is number of columns). - int64 n1() const { return n1_; } - int64 n2() const { return n2_; } - int64 height() const { return n1_; } - int64 width() const { return n2_; } - int64 num_elements() const { return n1_ * n2_; } - - // Low-level accessor for stuff like memcmp, handle with care. Returns pointer - // to the underlying storage of the array (similarly to std::vector::data()). - T* data() const { return const_cast(this)->values_.get(); } - - // Fills the array with the given value. - void Fill(const T& value) { - std::fill(&values_[0], &values_[0] + num_elements(), value); - } - - // Applies f to all cells in this array, in row-major order. - void Each(std::function f) { - for (int64 i0 = 0; i0 < n1(); ++i0) { - for (int64 i1 = 0; i1 < n2(); ++i1) { - f(i0, i1, &(*this)(i0, i1)); - } - } - } + int64 height() const { return this->dim(0); } + int64 width() const { return this->dim(1); } // Fills the array with a pattern of values of the form: // @@ -136,55 +74,14 @@ class Array2D { } } - // Fills the array with random normal variables of deviation value. - void FillRandom(const T& value, const double mean = 0.0, - const int seed = 12345) { - std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(value)); - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); - } - } - - // Returns a readable string representation of the array. - string ToString() const { - std::vector pieces = {"["}; - for (int64 row = 0; row < height(); ++row) { - pieces.push_back("["); - for (int64 col = 0; col < width(); ++col) { - pieces.push_back(tensorflow::strings::StrCat((*this)(row, col))); - pieces.push_back(", "); - } - pieces.pop_back(); - pieces.push_back("]"); - pieces.push_back(",\n "); - } - pieces.pop_back(); - pieces.push_back("]"); - return tensorflow::str_util::Join(pieces, ""); - } - - bool operator==(const Array2D& other) const { - if (n1() != other.n1() || n2() != other.n2()) { - return false; - } + // Applies f to all cells in this array, in row-major order. + void Each(std::function f) { for (int64 i0 = 0; i0 < n1(); ++i0) { for (int64 i1 = 0; i1 < n2(); ++i1) { - if ((*this)(i0, i1) != other(i0, i1)) { - return false; - } + f(i0, i1, &(*this)(i0, i1)); } } - return true; } - - bool operator!=(const Array2D& other) const { return !(*this == other); } - - private: - int64 n1_; - int64 n2_; - std::unique_ptr values_; }; // Returns a linspace-populated Array2D in the range [from, to] (inclusive) diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index 124ccd1975b3a9ab047e9bbbfb38921fe7386fe4..e9449f01ad69a5722f53cce09e2884e20a0def5a 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -32,22 +33,16 @@ limitations under the License. namespace xla { // Simple 3D array structure. -// -// The data layout in major-to-minor order is: n1, n2, n3. template -class Array3D { +class Array3D : public Array { public: // Creates an array of dimensions n1 x n2 x n3, uninitialized values. Array3D(const int64 n1, const int64 n2, const int64 n3) - : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { - Fill(T()); - } + : Array(std::vector{n1, n2, n3}) {} // Creates an array of dimensions n1 x n2 x n3, initialized to value. Array3D(const int64 n1, const int64 n2, const int64 n3, const T value) - : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { - Fill(value); - } + : Array(std::vector{n1, n2, n3}, value) {} // Creates an array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. @@ -58,84 +53,11 @@ class Array3D { // results in an array with n1=3, n2=4, n3=2. Array3D(std::initializer_list>> values) - : Array3D(values.size(), values.begin()->size(), - values.begin()->begin()->size()) { - int64 n1 = 0; - for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) { - int64 n2 = 0; - for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) { - int64 n3 = 0; - for (auto n3_it = n2_it->begin(); n3_it != n2_it->end(); - ++n3_it, ++n3) { - (*this)(n1, n2, n3) = *n3_it; - } - } - } - } + : Array(values) {} - Array3D(const Array3D& other) - : Array3D(other.n1(), other.n2(), other.n3()) { - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - } - - Array3D& operator=(const Array3D& other) { - n1_ = other.n1(); - n2_ = other.n2(); - n3_ = other.n3(); - values_.reset(new T[num_elements()]); - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - return *this; - } - - T& operator()(const int64 i1, const int64 i2, const int64 i3) { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - CHECK_LT(i3, n3_); - return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; - } - - const T& operator()(const int64 i1, const int64 i2, const int64 i3) const { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - CHECK_LT(i3, n3_); - return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; - } - - // Access to the array's dimensions. - int64 n1() const { return n1_; } - int64 n2() const { return n2_; } - int64 n3() const { return n3_; } - int64 num_elements() const { return n1_ * n2_ * n3_; } - - // Fills the array with the given value. - void Fill(const T& value) { - std::fill(&values_[0], &values_[0] + num_elements(), value); - } - - // Fills the array with sequentially increasing values. - void FillIota(const T& value) { - std::iota(&values_[0], &values_[0] + num_elements(), value); - } - - // Fills the array with random normal values with a mean of 0 and standard - // deviation of value. - void FillRandom(const T& value, const double mean = 0.0, - const int seed = 12345) { - std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(value)); - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); - } - } - - private: - int64 n1_; - int64 n2_; - int64 n3_; - std::unique_ptr values_; + int64 n1() const { return this->dim(0); } + int64 n2() const { return this->dim(1); } + int64 n3() const { return this->dim(2); } }; } // namespace xla diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index 4c7fce1aaf1faf4bd08bca38bc8eb2b47303b575..f8b2b2afe5fed9c465c2a1f39308b7f44311b16a 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -53,23 +54,15 @@ namespace xla { // more than one name is given above. See operator() for the exact // calculation of 1d indices from 4d indices. template -class Array4D { +class Array4D : public Array { public: // Creates a 4D array, uninitialized values. Array4D(int64 planes, int64 depth, int64 height, int64 width) - : planes_(planes), - depth_(depth), - height_(height), - width_(width), - values_(new T[planes * depth * height * width]) { - Fill(T()); - } + : Array(std::vector{planes, depth, height, width}) {} // Creates a 4D array, initialized to value. Array4D(int64 planes, int64 depth, int64 height, int64 width, T value) - : Array4D(planes, depth, height, width) { - Fill(value); - } + : Array(std::vector{planes, depth, height, width}, value) {} // Creates a 4D array, filled with values. // @@ -80,144 +73,26 @@ class Array4D { Array4D(int64 planes, int64 depth, int64 height, int64 width, const Container& values) : Array4D(planes, depth, height, width) { - SetValues(values); + this->SetValues(values); } // Construct an Array4D with the given nested initializer list. Array4D(std::initializer_list>>> values) - : Array4D(values.size(), values.begin()->size(), - values.begin()->begin()->size(), - values.begin()->begin()->begin()->size()) { - int64 plane = 0; - for (const auto values_in_plane : values) { - DCHECK_EQ(values_in_plane.size(), depth_); - int64 depth = 0; - for (const auto values_in_depth : values_in_plane) { - DCHECK_EQ(values_in_depth.size(), height_); - int64 height = 0; - for (const auto values_in_height : values_in_depth) { - DCHECK_EQ(values_in_height.size(), width_); - int64 width = 0; - for (const auto element_value : values_in_height) { - (*this)(plane, depth, height, width) = element_value; - ++width; - } - ++height; - } - ++depth; - } - ++plane; - } - } - - Array4D(const Array4D& other) - : Array4D(other.planes(), other.depth(), other.height(), other.width()) { - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - } - - Array4D& operator=(const Array4D& other) { - planes_ = other.planes(); - depth_ = other.depth(); - height_ = other.height(); - width_ = other.width(); - values_.reset(new T[num_elements()]); - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - return *this; - } - - T& operator()(int64 plane, int64 depth, int64 height, int64 width) { - CHECK_LT(plane, planes_); - CHECK_LT(depth, depth_); - CHECK_LT(height, height_); - CHECK_LT(width, width_); - return values_[plane * (depth_ * height_ * width_) + - depth * (height_ * width_) + height * (width_) + width]; - } - const T& operator()(int64 plane, int64 depth, int64 height, - int64 width) const { - return const_cast(this)->operator()(plane, depth, height, width); - } - - int64 width() const { return width_; } - int64 height() const { return height_; } - int64 depth() const { return depth_; } - int64 planes() const { return planes_; } + : Array(values) {} // Numerically-named aliases for the various dimensions. This matches the // dimension names used in array3d. - int64 n4() const { return width_; } - int64 n3() const { return height_; } - int64 n2() const { return depth_; } - int64 n1() const { return planes_; } - int64 num_elements() const { return width_ * height_ * depth_ * planes_; } - - // Sets all the values in the array to values. - template > - void SetValues(const Container& container) { - CHECK_EQ(std::distance(std::begin(container), std::end(container)), - num_elements()); - std::copy(std::begin(container), std::end(container), &values_[0]); - } - - // Fills the array with the given value. - void Fill(const T& value) { - std::fill(&values_[0], &values_[0] + num_elements(), value); - } + int64 n4() const { return this->dim(3); } + int64 n3() const { return this->dim(2); } + int64 n2() const { return this->dim(1); } + int64 n1() const { return this->dim(0); } - // Fills the array with iota. - void FillIota(const T& value) { - std::iota(&values_[0], &values_[0] + num_elements(), value); - } - - // Fills the array with random variable with a deviation of value and a mean - // of mean. - void FillRandom(const T& value, const double mean = 0.0, - const int seed = 12345) { - std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(value)); - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); - } - } - - // Fills values with the sequence i*multiplier for i=0,1,... - void FillWithMultiples(float multiplier) { - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = i * multiplier; - } - } - - // Invokes a callback with the (indices, value_ptr) for each cell in the 4D - // array. - void Each(std::function, T*)> f) { - for (int64 plane = 0; plane < planes(); ++plane) { - for (int64 depth = 0; depth < this->depth(); ++depth) { - for (int64 height = 0; height < this->height(); ++height) { - for (int64 width = 0; width < this->width(); ++width) { - auto& value = (*this)(plane, depth, height, width); - f({plane, depth, height, width}, &value); - } - } - } - } - } - - // Invokes a callback with the (indices, value) for each cell in the 4D array. - void Each( - std::function, T)> f) const { - // We const_cast to be able to use the common non-const implementation, - // but prevent modification of the data by passing it by-value to the - // caller. - const_cast(this)->Each( - [&f](tensorflow::gtl::ArraySlice indices, T* value) { - f(indices, *value); - }); - } + int64 width() const { return this->dim(3); } + int64 height() const { return this->dim(2); } + int64 depth() const { return this->dim(1); } + int64 planes() const { return this->dim(0); } // Fills all of the {p,z} with the array provided, which specifies {y,x}. void FillWithYX(const Array2D& value) { @@ -267,38 +142,6 @@ class Array4D { } } } - - // Returns a string representation of the 4D array suitable for debugging. - string ToString() const { - std::vector pieces = { - tensorflow::strings::Printf("p=%lld,z=%lld,y=%lld,x=%lld {\n", planes(), - depth(), height(), width())}; - for (int64 plane = 0; plane < planes_; ++plane) { - pieces.push_back(" {\n"); - for (int64 depth = 0; depth < depth_; ++depth) { - pieces.push_back(" {\n"); - for (int64 height = 0; height < height_; ++height) { - pieces.push_back(" {"); - for (int64 width = 0; width < width_; ++width) { - pieces.push_back(tensorflow::strings::StrCat( - (*this)(plane, depth, height, width), ", ")); - } - pieces.push_back("},\n"); - } - pieces.push_back(" },\n"); - } - pieces.push_back(" },\n"); - } - pieces.push_back("}"); - return tensorflow::str_util::Join(pieces, ""); - } - - private: - int64 planes_; - int64 depth_; - int64 height_; - int64 width_; - std::unique_ptr values_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..093784f541b3bd18f4a1fc1b665cd0d17a892f28 --- /dev/null +++ b/tensorflow/compiler/xla/array_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/array.h" + +#include + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(ArrayTest, UninitializedDimsCtor) { + Array uninit({2, 3}); + EXPECT_EQ(uninit.num_dimensions(), 2); + EXPECT_EQ(uninit.dim(0), 2); + EXPECT_EQ(uninit.dim(1), 3); + EXPECT_EQ(uninit.num_elements(), 6); +} + +TEST(ArrayTest, FillCtor) { + Array fullof7({1, 2, 3}, 7); + + EXPECT_EQ(fullof7.dim(0), 1); + EXPECT_EQ(fullof7.dim(1), 2); + EXPECT_EQ(fullof7.dim(2), 3); + + for (int64 n0 = 0; n0 < fullof7.dim(0); ++n0) { + for (int64 n1 = 0; n1 < fullof7.dim(1); ++n1) { + for (int64 n2 = 0; n2 < fullof7.dim(2); ++n2) { + EXPECT_EQ(fullof7(n0, n1, n2), 7); + } + } + } +} + +TEST(ArrayTest, InitializerListCtor) { + Array arr({{1, 2, 3}, {4, 5, 6}}); + + EXPECT_EQ(arr.dim(0), 2); + EXPECT_EQ(arr.dim(1), 3); + + EXPECT_EQ(arr(0, 0), 1); + EXPECT_EQ(arr(0, 1), 2); + EXPECT_EQ(arr(0, 2), 3); + EXPECT_EQ(arr(1, 0), 4); + EXPECT_EQ(arr(1, 1), 5); + EXPECT_EQ(arr(1, 2), 6); +} + +TEST(ArrayTest, IndexingReadWrite) { + Array arr({2, 3}); + + EXPECT_EQ(arr(1, 1), 0); + EXPECT_EQ(arr(1, 2), 0); + arr(1, 1) = 51; + arr(1, 2) = 61; + EXPECT_EQ(arr(1, 1), 51); + EXPECT_EQ(arr(1, 2), 61); +} + +TEST(ArrayTest, IndexingReadWriteBool) { + Array arr{{false, true, false}, {false, true, false}}; + + EXPECT_EQ(arr(0, 1), true); + EXPECT_EQ(arr(0, 2), false); + arr(0, 1) = false; + arr(0, 2) = true; + EXPECT_EQ(arr(0, 1), false); + EXPECT_EQ(arr(0, 2), true); +} + +TEST(ArrayTest, Fill) { + Array fullof7({2, 3}, 7); + for (int64 n1 = 0; n1 < fullof7.dim(0); ++n1) { + for (int64 n2 = 0; n2 < fullof7.dim(1); ++n2) { + EXPECT_EQ(fullof7(n1, n2), 7); + } + } + + fullof7.Fill(11); + for (int64 n1 = 0; n1 < fullof7.dim(0); ++n1) { + for (int64 n2 = 0; n2 < fullof7.dim(1); ++n2) { + EXPECT_EQ(fullof7(n1, n2), 11); + } + } +} + +TEST(ArrayTest, DataPointer) { + Array arr{{1, 2, 3}, {4, 5, 6}}; + EXPECT_EQ(arr.data()[0], 1); +} + +TEST(ArrayTest, Stringification1D) { + Array arr({2}, 1); + const string expected = R"([1, 1])"; + EXPECT_EQ(expected, arr.ToString()); +} + +TEST(ArrayTest, Stringification2D) { + Array arr({2, 3}, 7); + const string expected = "[[7, 7, 7],\n [7, 7, 7]]"; + EXPECT_EQ(expected, arr.ToString()); +} + +TEST(ArrayTest, Stringification3D) { + Array arr({2, 3, 4}, 5); + const string expected = R"([[[5, 5, 5, 5], + [5, 5, 5, 5], + [5, 5, 5, 5]], + [[5, 5, 5, 5], + [5, 5, 5, 5], + [5, 5, 5, 5]]])"; + EXPECT_EQ(expected, arr.ToString()); +} + +TEST(ArrayTest, Each) { + Array arr({2, 3, 4}); + arr.FillWithMultiples(1); + + int64 each_count = 0, each_sum = 0; + arr.Each([&](tensorflow::gtl::ArraySlice idx, int cell) { + int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2]; + EXPECT_EQ(lin_idx, cell); + each_count++; + each_sum += cell; + }); + EXPECT_EQ(arr.num_elements(), each_count); + EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 2b142d933dbc8c5a7823f9426c423b59425a85bc..b6126981431dc9a3520b6c96321c453bc955e7c0 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -41,7 +41,9 @@ cc_library( srcs = ["padding.cc"], hdrs = ["padding.h"], deps = [ + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 387253617e4f37a1561d4659eb796a181f0b5bee..92cd8e729d659c4ff24c156d89f29275848c3cee 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -206,6 +206,7 @@ StatusOr> Client::Execute( *request.mutable_execution_options() = *execution_options; } for (GlobalData* argument : arguments) { + CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); } @@ -241,9 +242,6 @@ StatusOr>> Client::ExecuteParallel( for (GlobalData* argument : computation.arguments) { *single_request.add_arguments() = argument->handle(); } - if (computation.device_handle != nullptr) { - *single_request.mutable_device_handle() = *computation.device_handle; - } *single_request.mutable_execution_options() = computation.execution_options; *request.add_requests() = single_request; } diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index e72816a6217afd6a827642bbe3aa205409ef5718..a716159f9e74041c4823ad20b46fa94c2d7b9d8c 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -45,6 +45,10 @@ class Client { // * If execution_options is not nullptr, these options are passed to the // service to affect how it compiles our computation. (The pointer does not // need to live beyond this call.) + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. StatusOr> Execute( @@ -54,12 +58,13 @@ class Client { ExecutionProfile* execution_profile = nullptr); // A struct to represent a computation instance to be executed. - // * If device_handle is not nullptr, the computation is executed on a device - // associated with the handle. Otherwise, a device is chosen by the service. + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. struct ComputationInstance { const Computation& computation; std::vector arguments; - const DeviceHandle* device_handle; ExecutionOptions execution_options; ExecutionProfile* execution_profile; }; diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index a80412e951aaea74c36b3fd372c24fe6ec8f9fb8..edf5a1822cb2520cada5e0a8327a37e0f9c5e759 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -489,6 +489,16 @@ ComputationDataHandle ComputationBuilder::Collapse( } std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); + VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); + VLOG(3) << "dims to collapse: " + << tensorflow::str_util::Join(dims_to_collapse, ","); + + if (dims_to_collapse.size() <= 1) { + // Not collapsing anything, trivially we can return the operand versus + // enqueueing a trivial reshape. + return operand; + } + std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { @@ -498,6 +508,9 @@ ComputationDataHandle ComputationBuilder::Collapse( } } + VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") + << "]"; + return Reshape(operand, new_sizes); } @@ -650,7 +663,7 @@ bool ComputationBuilder::VerifyConvolution( return false; } int num_dims = ShapeUtil::Rank(lhs_shape); - if (num_dims < 3) { + if (num_dims < 2) { NoteError(InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. " "Got: %s and %s", @@ -900,6 +913,17 @@ ComputationDataHandle ComputationBuilder::CustomCall( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::Complex( + const ComputationDataHandle& real, const ComputationDataHandle& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Conj( + const ComputationDataHandle& operand) { + return Complex(Real(operand), Neg(Imag(operand))); +} + ComputationDataHandle ComputationBuilder::Add( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -942,21 +966,39 @@ ComputationDataHandle ComputationBuilder::Min( return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalAnd( +ComputationDataHandle ComputationBuilder::And( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LOGICAL_AND, lhs, rhs, broadcast_dimensions); + return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalOr( +ComputationDataHandle ComputationBuilder::Or( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LOGICAL_OR, lhs, rhs, broadcast_dimensions); + return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalNot( +ComputationDataHandle ComputationBuilder::Not( const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOGICAL_NOT, operand); + return UnaryOp(UNOP_NOT, operand); +} + +ComputationDataHandle ComputationBuilder::ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); } ComputationDataHandle ComputationBuilder::Abs( @@ -964,6 +1006,12 @@ ComputationDataHandle ComputationBuilder::Abs( return UnaryOp(UNOP_ABS, operand); } +ComputationDataHandle ComputationBuilder::Atan2( + const ComputationDataHandle& y, const ComputationDataHandle& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions); +} + ComputationDataHandle ComputationBuilder::Exp( const ComputationDataHandle& operand) { return UnaryOp(UNOP_EXP, operand); @@ -1009,6 +1057,16 @@ ComputationDataHandle ComputationBuilder::Tanh( return UnaryOp(UNOP_TANH, operand); } +ComputationDataHandle ComputationBuilder::Real( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_REAL, operand); +} + +ComputationDataHandle ComputationBuilder::Imag( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_IMAG, operand); +} + ComputationDataHandle ComputationBuilder::IsFinite( const ComputationDataHandle& operand) { return UnaryOp(UNOP_IS_FINITE, operand); @@ -1433,10 +1491,20 @@ ComputationDataHandle ComputationBuilder::ReduceWindow( return ComputationDataHandle(); } - return ReduceWindowWithGeneralPadding( - operand, init_value, computation, window_dimensions, window_strides, + Status padding_valid = + ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), + window_dimensions, window_strides); + if (!padding_valid.ok()) { + first_error_ = padding_valid; + return ComputationDataHandle(); + } + + std::vector> padding_values = MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding)); + window_dimensions, window_strides, padding); + return ReduceWindowWithGeneralPadding(operand, init_value, computation, + window_dimensions, window_strides, + padding_values); } ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( @@ -1739,8 +1807,10 @@ void ComputationBuilder::SetDeviceAssignment( /* static */ ConvolutionDimensionNumbers ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(kConvBatchDimension); - dimension_numbers.set_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_input_batch_dimension(kConvBatchDimension); + dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_output_batch_dimension(kConvBatchDimension); + dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); dimension_numbers.set_kernel_output_feature_dimension( kConvKernelOutputDimension); dimension_numbers.set_kernel_input_feature_dimension( @@ -1754,15 +1824,17 @@ ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { /* static */ StatusOr ComputationBuilder::CreateConvDimensionNumbers( - int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 output_batch, + int64 output_feature, int64 first_spatial, int64 second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({batch, feature, first_spatial, second_spatial}).size() != - 4) { + if (std::set( + {input_batch, input_feature, first_spatial, second_spatial}) + .size() != 4) { return FailedPrecondition( "dimension numbers for the input are not unique: (%lld, %lld, %lld, " "%lld)", - batch, feature, first_spatial, second_spatial); + input_batch, input_feature, first_spatial, second_spatial); } if (std::set({kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial}) @@ -1773,9 +1845,19 @@ ComputationBuilder::CreateConvDimensionNumbers( kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial); } + if (std::set( + {output_batch, output_feature, first_spatial, second_spatial}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the output are not unique: (%lld, %lld, %lld, " + "%lld)", + output_batch, output_feature, first_spatial, second_spatial); + } ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(batch); - dimension_numbers.set_feature_dimension(feature); + dimension_numbers.set_input_batch_dimension(input_batch); + dimension_numbers.set_input_feature_dimension(input_feature); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); dimension_numbers.add_spatial_dimensions(first_spatial); dimension_numbers.add_spatial_dimensions(second_spatial); dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 73972c1290f5f2dda13cece9df09782c4ab0b709..d2f0c7cff00bdf787ddc5e2b1c015eb8d1724df0 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -138,6 +138,11 @@ class ComputationBuilder { ComputationDataHandle ConstantR2( std::initializer_list> values); template + ComputationDataHandle ConstantFromArrayWithLayout( + const Array& values, const Layout& layout); + template + ComputationDataHandle ConstantFromArray(const Array& values); + template ComputationDataHandle ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout); template @@ -201,6 +206,16 @@ class ComputationBuilder { // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must // be a consecutive, in-order subsequence of the operand dimensions. // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. ComputationDataHandle Collapse(const ComputationDataHandle& operand, @@ -344,7 +359,8 @@ class ComputationBuilder { // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an // error if either the input or the weight dimension numbers have conflicts. static StatusOr CreateConvDimensionNumbers( - int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 output_batch, + int64 output_feature, int64 first_spatial, int64 second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial); @@ -415,6 +431,14 @@ class ComputationBuilder { // of the operands is a scalar, or an explicit broadcast dimension is given // (see g3doc for more details). + // Enqueues a complex compose instruction onto the computation. + ComputationDataHandle Complex( + const ComputationDataHandle& real, const ComputationDataHandle& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. + ComputationDataHandle Conj(const ComputationDataHandle& operand); + // Enqueues an add instruction onto the computation. ComputationDataHandle Add( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, @@ -451,15 +475,25 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice broadcast_dimensions = {}); // Element-wise logical operators - ComputationDataHandle LogicalAnd( + ComputationDataHandle And( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle LogicalOr( + ComputationDataHandle Or( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle LogicalNot(const ComputationDataHandle& lhs); + ComputationDataHandle Not(const ComputationDataHandle& operand); + + ComputationDataHandle ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. @@ -516,6 +550,11 @@ class ComputationBuilder { // Enqueues an abs instruction onto the computation. ComputationDataHandle Abs(const ComputationDataHandle& operand); + // Enqueues a atan2 instruction onto the computation. + ComputationDataHandle Atan2( + const ComputationDataHandle& y, const ComputationDataHandle& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + // Enqueues an exp instruction onto the computation. ComputationDataHandle Exp(const ComputationDataHandle& operand); @@ -544,6 +583,12 @@ class ComputationBuilder { // Enqueues a tanh instruction onto the computation. ComputationDataHandle Tanh(const ComputationDataHandle& operand); + // Enqueues a real-part instruction onto the computation. + ComputationDataHandle Real(const ComputationDataHandle& operand); + + // Enqueues an imaginary-part instruction onto the computation. + ComputationDataHandle Imag(const ComputationDataHandle& operand); + // Enqueues a float32 sqrt instruction onto the computation. // (float32 is specified as there is an implicit float32 0.5f constant // exponent). @@ -889,48 +934,54 @@ ComputationDataHandle ComputationBuilder::ConstantR2( } template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( + const Array& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); }); } +template +ComputationDataHandle ComputationBuilder::ConstantFromArray( + const Array& values) { + return ConstantOp( + [&values](Literal* literal) { literal->PopulateFromArray(values); }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); }); + return ConstantFromArray(values); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 969b0eee1d195a36728f16a598add4b3b850ed60..24048a1e5a782661ba577ba50e3b5b2914f17c0a 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -89,24 +89,24 @@ Computation CreateScalarMinComputation(PrimitiveType type, const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); } -Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder) { +Computation CreateScalarAndComputation(ComputationBuilder* builder) { return CreateScalarComputation( - "logical_and", PRED, builder, + "and", PRED, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->LogicalAnd(lhs, rhs); }); + const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); } -Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder) { +Computation CreateScalarOrComputation(ComputationBuilder* builder) { return CreateScalarComputation( - "logical_or", PRED, builder, + "or", PRED, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->LogicalOr(lhs, rhs); }); + const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); } StatusOr Any(const ComputationDataHandle& predicates, ComputationBuilder* builder) { auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarLogicalOrComputation(builder); + Computation logical_or = CreateScalarOrComputation(builder); TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, builder->GetShape(predicates)); std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f43d35fe4a52016d4054af28835d6b66a35217d4..ae89784bc227d837cf15f0a89687dd00dccc2745 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -45,10 +45,10 @@ Computation CreateScalarMinComputation(PrimitiveType type, ComputationBuilder* builder); // Creates a scalar logical AND computation and returns it. -Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder); +Computation CreateScalarAndComputation(ComputationBuilder* builder); // Creates a scalar logical OR computation and returns it. -Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder); +Computation CreateScalarOrComputation(ComputationBuilder* builder); // Returns whether any predicate in "predicates" is set. // diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 482d53cf330f152f496b77233714f93991fef6f0..e6645e4941bd04c658b67117bb689f6fdef7dfc1 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -79,6 +79,24 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { })); break; } + case S64: { + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), + std::numeric_limits::max()); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); + break; + } + case PRED: { + std::uniform_int_distribution generator(0, 1); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); + break; + } default: return Unimplemented("Unsupported type for fake literal generation: %s", ShapeUtil::HumanString(shape).c_str()); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index c885b815ebef60bbabfdbd97642d0be9bbbf49e8..15c744ecd349e91dc703bec5708d78a896f132c3 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -175,10 +175,15 @@ StatusOr> LocalExecutable::Run( TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); ExecutableRunOptions actual_options = options; + + Backend::StreamPtr stream; if (options.stream() == nullptr) { + // NB! The lifetime of `stream` needs to match the lifetime of + // `actual_options` (otherwise we will end up using a returned stream in + // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if" + // scope. TF_ASSIGN_OR_RETURN( - Backend::StreamPtr stream, - BorrowStreamForDevice(options.device_ordinal(), backend_)); + stream, BorrowStreamForDevice(options.device_ordinal(), backend_)); actual_options.set_stream(stream.get()); } if (options.allocator() == nullptr) { diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 0b18d8946a2e62a810f875b4d79fd5375e787487..6a9cf466ac0a43ce214ef0e6aae9e6295f137b0f 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -17,17 +17,34 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { +Status ValidatePaddingValues( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides) { + bool ok = input_dimensions.size() == window_dimensions.size() && + input_dimensions.size() == window_strides.size(); + if (!ok) { + return InvalidArgument( + "Want input dimensions size %zu = window dimensions size %zu = window " + "strides size %zu", + input_dimensions.size(), window_dimensions.size(), + window_strides.size()); + } + return Status::OK(); +} + std::vector> MakePadding( tensorflow::gtl::ArraySlice input_dimensions, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - CHECK_EQ(input_dimensions.size(), window_dimensions.size()); - CHECK_EQ(input_dimensions.size(), window_strides.size()); + TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions, + window_strides)); std::vector> low_high_padding; switch (padding) { case Padding::kValid: diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h index dce2d87e8da8b3d9fd138a712c459ea0081372e0..e23b0b3a90a091bf80973525810793c3eda4a036 100644 --- a/tensorflow/compiler/xla/client/padding.h +++ b/tensorflow/compiler/xla/client/padding.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -37,6 +38,14 @@ enum class Padding { kValid, }; +// Validates that the slices are acceptable for determining padding -- this can +// be used to check the preconditions of MakePadding below to produce an error +// message that can be returned to the user. +Status ValidatePaddingValues( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides); + // Returns the padding needed for the base area, given the base area dimensions, // window dimensions, strides, and the type of padding. // @@ -51,7 +60,7 @@ enum class Padding { std::vector> MakePadding( tensorflow::gtl::ArraySlice input_dimensions, tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice strides, Padding padding); + tensorflow::gtl::ArraySlice window_strides, Padding padding); } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 011fc3c194e0eb9ebd6b9e42571deddaf25c09ff..5c2cc2a7a99cc51ded3d98c9dd5903e4b3078548 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return CreateDefaultLayoutForRank(shape.dimensions_size()); } +/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) { + return CreateDefaultLayoutForRank(rank); +} + /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { return CreateDefaultLayoutForRank(2); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 5de0a653f66688ac75fc377c18ff93012314abdd..bc42e222292933be35e82d1fe50802e8830d16b3 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -40,6 +40,7 @@ class LayoutUtil { static Layout GetDefaultLayoutForShape(const Shape& shape); // Helper functions that create default layouts for various ranks. + static Layout GetDefaultLayoutForRank(int64 rank); static Layout GetDefaultLayoutForR2(); static Layout GetDefaultLayoutForR3(); static Layout GetDefaultLayoutForR4(); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 8892bfbe929d168c602af24cfbb507256dc05328..f2cdd9669c727bb778fce495ede0faaf2d9a923d 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -206,9 +206,9 @@ void AllocateFlags() { flag_values->xla_gpu_disable_multi_streaming(), "If true, multi-streaming in the GPU backend is disabled."), tensorflow::Flag( - "xla_dump_debug_json_to", - flag_values->mutable_xla_dump_debug_json_to(), - "Dump compilation artifacts as JSON into this directory."), + "xla_dump_hlo_proto_to", + flag_values->mutable_xla_dump_hlo_proto_to(), + "Dump compilation artifacts as proto binary into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 79e40c12625c41b7234542381d0ca528be7eaed4..8fc8644a60ef62d7ba5e7f0cc11253742395f09b 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -173,6 +173,8 @@ Status Literal::Copy(const Literal& src_literal, return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: return CopyRange(src_literal, src_base, dest_base, copy_size); + case C64: + return CopyRange(src_literal, src_base, dest_base, copy_size); case PRED: return CopyRange(src_literal, src_base, dest_base, copy_size); default: @@ -202,6 +204,8 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(0); case F64: return *Literal::CreateR0(0); + case C64: + return *Literal::CreateR0(0); case PRED: return *Literal::CreateR0(false); case S16: @@ -234,6 +238,8 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(1); case F64: return *Literal::CreateR0(1); + case C64: + return *Literal::CreateR0(1); case PRED: return *Literal::CreateR0(true); case S16: @@ -269,6 +275,8 @@ Status Literal::Copy(const Literal& src_literal, case F64: return *Literal::CreateR0( -std::numeric_limits::infinity()); + case C64: + LOG(FATAL) << "C64 element type has no minimum value"; case PRED: return *Literal::CreateR0(false); case S16: @@ -522,6 +530,10 @@ string Literal::GetAsString( return tensorflow::strings::StrCat(Get(multi_index)); case F64: return tensorflow::strings::StrCat(Get(multi_index)); + case C64: { + complex64 c = Get(multi_index); + return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")"); + } case F16: return tensorflow::strings::StrCat(Get(multi_index)); default: @@ -716,6 +728,8 @@ void* Literal::MutableInternalData() { return reinterpret_cast(f32s_.data()); case F64: return reinterpret_cast(f64s_.data()); + case C64: + return reinterpret_cast(c64s_.data()); case F16: return reinterpret_cast(f16s_.data()); default: @@ -754,6 +768,9 @@ void Literal::Reserve(int64 num_elements) { case F64: Resize(num_elements, 0); break; + case C64: + Resize(num_elements, 0); + break; case F16: Resize(num_elements, static_cast(0.0f)); break; @@ -790,6 +807,9 @@ tensorflow::Status Literal::ValidateLiteral() const { case F64: actual = f64s_size(); break; + case C64: + actual = c64s_size(); + break; case F16: actual = f16s().size() / sizeof(half); break; @@ -843,6 +863,26 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { return result_literal; } +template +std::unique_ptr ConvertToC64(const Literal& src_literal) { + auto result_literal = MakeUnique(); + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = src_literal.shape(); + result_shape->set_element_type(C64); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + using NativeSrcT = + typename primitive_util::PrimitiveTypeToNative::type; + tensorflow::gtl::ArraySlice src_data = + src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = complex64(static_cast(src_data[i]), 0); + } + return result_literal; +} + template std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); @@ -870,6 +910,8 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F64) #undef CONVERT_IF_TYPES_MATCH + case C64: + return ConvertToC64(src_literal); // Other types are not yet supported. default: return InvalidArgument( @@ -966,6 +1008,8 @@ bool Literal::operator==(const Literal& other) const { return EqualElements(*this, other, 0, &multi_index); case F16: return EqualElements(*this, other, 0, &multi_index); + case C64: + return EqualElements(*this, other, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: Literal::Equal for type " << PrimitiveType_Name(shape().element_type()); @@ -1065,6 +1109,12 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { values->size()); } +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_c64s(); + return {values->data(), values->size()}; +} + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // TODO - there is an endianess problem here. fix it, or wait for uint16 @@ -1144,6 +1194,13 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { f16s().size() / sizeof(half)); } +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const { + CHECK_EQ(shape().element_type(), C64); + return c64s(); +} + template static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -1211,6 +1268,15 @@ bool Literal::IsAllFloat(float value) const { } } +bool Literal::IsAllComplex(complex64 value) const { + switch (shape().element_type()) { + case C64: + return AllElementsEqualValue(*this, value); + default: + return false; + } +} + bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { switch (shape().element_type()) { case U8: @@ -1229,6 +1295,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { return Get(indices) == 0.0f; case F64: return Get(indices) == 0.0; + case C64: + return Get(indices) == complex64(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); case PRED: @@ -1298,12 +1366,27 @@ void Literal::Resize(int64 num_elements, half value) { mutable_f16s()->resize(num_elements, value); } +template <> +void Literal::Resize(int64 num_elements, complex64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_c64s()->resize(num_elements, value); +} + template -static void CopyToRepeatedField(RepeatedFieldT* dest, - const std::vector& src) { +void CopyToRepeatedField(RepeatedFieldT* dest, + const std::vector& src) { *dest = RepeatedFieldT(src.begin(), src.end()); } +template <> +void CopyToRepeatedField, complex64>( + tensorflow::protobuf::RepeatedField* dest, + const std::vector& src) { + *dest = tensorflow::protobuf::RepeatedField( + reinterpret_cast(src.data()), + reinterpret_cast(src.data()) + src.size() * 2); +} + LiteralProto Literal::ToProto() const { LiteralProto proto; proto.Clear(); @@ -1338,6 +1421,9 @@ LiteralProto Literal::ToProto() const { case F64: CopyToRepeatedField(proto.mutable_f64s(), f64s()); break; + case C64: + CopyToRepeatedField(proto.mutable_c64s(), c64s()); + break; case TUPLE: for (const auto& tuple : tuple_literals()) { *proto.add_tuple_literals() = tuple.ToProto(); @@ -1351,11 +1437,21 @@ LiteralProto Literal::ToProto() const { } template -static void CopyFromRepeatedField(std::vector* dest, - const RepeatedFieldT& src) { +void CopyFromRepeatedField(std::vector* dest, + const RepeatedFieldT& src) { *dest = std::vector(src.begin(), src.end()); } +template <> +void CopyFromRepeatedField, + complex64>( + std::vector* dest, + const tensorflow::protobuf::RepeatedField& src) { + *dest = std::vector( + reinterpret_cast(src.data()), + reinterpret_cast(src.data()) + src.size() / 2); +} + void Literal::CopyFromProto(const LiteralProto& literal_proto) { if (!literal_proto.has_shape()) { return; @@ -1394,6 +1490,9 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { case F64: CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); break; + case C64: + CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s()); + break; case TUPLE: for (const auto& proto : literal_proto.tuple_literals()) { mutable_tuple_literals()->push_back(Literal(proto)); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e8cee732d4cf5629c1e2b9c98d1f1bbe1e29a122..a1e288829f22835f94c6e3c041796f84d995211c 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -159,6 +159,10 @@ class Literal { const std::vector& f64s() const { return f64s_; } std::vector* mutable_f64s() { return &f64s_; } + int c64s_size() const { return c64s().size(); } + const std::vector& c64s() const { return c64s_; } + std::vector* mutable_c64s() { return &c64s_; } + int tuple_literals_size() const { return tuple_literals().size(); } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } Literal* add_tuple_literals() { @@ -334,6 +338,11 @@ class Literal { // WithLayout use the default XLA layout for the literal's linear // representation in memory. template + static std::unique_ptr CreateFromArray(const Array& values); + template + static std::unique_ptr CreateFromArrayWithLayout( + const Array& values, const Layout& layout); + template static std::unique_ptr CreateR2FromArray2D( const Array2D& values); template @@ -481,6 +490,11 @@ class Literal { std::initializer_list> values, const Layout& layout); template + void PopulateFromArray(const Array& values); + template + void PopulateFromArrayWithLayout(const Array& values, + const Layout& layout); + template void PopulateR2FromArray2D(const Array2D& values); template void PopulateR2FromArray2DWithLayout(const Array2D& values, @@ -550,6 +564,17 @@ class Literal { // e.g. -0.5. bool IsAllFloat(float value) const; + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + bool IsAllComplex(complex64 value) const; + // Returns whether this literal is zero at the specified index. This literal // must be an array. bool IsZero(tensorflow::gtl::ArraySlice indices) const; @@ -600,6 +625,7 @@ class Literal { std::vector f16s_; std::vector f32s_; std::vector f64s_; + std::vector c64s_; std::vector tuple_literals_; }; @@ -648,6 +674,10 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const; + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); @@ -684,6 +714,9 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + template <> void Literal::Resize(int64 num_elements, bool value); @@ -714,6 +747,9 @@ void Literal::Resize(int64 num_elements, double value); template <> void Literal::Resize(int64 num_elements, half value); +template <> +void Literal::Resize(int64 num_elements, complex64 value); + template /* static */ std::unique_ptr Literal::CreateR0(NativeT value) { auto literal = MakeUnique(); @@ -816,33 +852,42 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( + const Array& values, const Layout& layout) { auto literal = MakeUnique(); - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); return literal; } +template +/* static */ std::unique_ptr Literal::CreateFromArray( + const Array& values) { + return CreateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template +/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return CreateFromArrayWithLayout(values, layout); +} + template /* static */ std::unique_ptr Literal::CreateR2FromArray2D( const Array2D& values) { - return CreateR2FromArray2DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR2()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR3FromArray3DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3D( const Array3D& values) { - return CreateR3FromArray3DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR3()); + return CreateFromArray(values); } template @@ -901,16 +946,13 @@ template template /* static */ std::unique_ptr Literal::CreateR4FromArray4D( const Array4D& values) { - return CreateR4FromArray4DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR4()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR4FromArray4DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template @@ -1070,82 +1112,53 @@ void Literal::PopulateR2( } template -void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +void Literal::PopulateFromArrayWithLayout(const Array& values, + const Layout& layout) { *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); + primitive_util::NativeToPrimitiveType(), values.dimensions(), + AsInt64Slice(layout.minor_to_major())); + Reserve(values.num_elements()); + values.Each([this](tensorflow::gtl::ArraySlice indices, + NativeT value) { this->Set(indices, value); }); +} - const int64 dim1_size = values.width(); - const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - Reserve(dim1_size * dim0_size); - for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { - for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set({dim0, dim1}, values(dim0, dim1)); - } - } +template +void Literal::PopulateFromArray(const Array& values) { + PopulateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template +void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR2FromArray2D(const Array2D& values) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); + PopulateFromArray(values); } template void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.n1(), values.n2(), values.n3()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR3FromArray3D(const Array3D& values) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); + PopulateFromArray(values); } template void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.planes(), values.depth(), values.height(), values.width()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - CHECK_EQ(values.n4(), shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); - } - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); + PopulateFromArray(values); } template diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index e7dedd08218d8a17c5e332e5cda7bedcc26f6703..a9af4849e2124fd47ae42cc06ac8cc5ca5a22cb7 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -107,6 +107,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f16_lit = Literal::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", f16_lit->ToString()); + + auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); + ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -331,6 +334,19 @@ TEST_F(LiteralUtilTest, TupleEquality) { EXPECT_NE(*tuple1, *different_tuple); } +TEST_F(LiteralUtilTest, C64Equality) { + // Test equality with tuples. + auto vector = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. + auto vector_clone = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(*vector, *vector_clone); + + auto vector_reversed = Literal::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(*vector, *vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = Literal::CreateR0(0.0); auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); @@ -381,6 +397,9 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + complex64 c8_9 = {8, 9}; + EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(Literal::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) @@ -411,6 +430,25 @@ TEST_F(LiteralUtilTest, IsAllFloat) { Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } +TEST_F(LiteralUtilTest, IsAllComplex) { + // IsAllComplex always returns false when the literal is not complex. + EXPECT_FALSE(Literal::CreateR0(false)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + + complex64 c8_9 = {8, 9}; + complex64 c7_9 = {7, 9}; + EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c7_9}}) + ->IsAllComplex({8.0f, 9.0f})); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = Literal::CreateR0(0.0f); auto scalar_one = Literal::CreateR0(1.0f); @@ -422,12 +460,17 @@ TEST_F(LiteralUtilTest, IsZero) { EXPECT_TRUE(array->IsZero({0, 2})); EXPECT_TRUE(array->IsZero({1, 1})); EXPECT_FALSE(array->IsZero({1, 2})); + + auto complex_zero = Literal::CreateR0(0.0f); + auto complex_nonzero = Literal::CreateR0(0.5f); + EXPECT_TRUE(complex_zero->IsZero({})); + EXPECT_FALSE(complex_nonzero->IsZero({})); } template class LiteralUtilTestTemplated : public ::testing::Test {}; -using TestedTypes = ::testing::Types; +using TestedTypes = ::testing::Types; TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { @@ -626,13 +669,28 @@ TEST_F(LiteralUtilTest, PopulateR1S64) { EXPECT_EQ(output, *expected); } -TEST_F(LiteralUtilTest, PopulateR2U64) { +TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output; output.PopulateR1({{77, 88}}); auto expected = Literal::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateR1C64) { + Literal output; + output.PopulateR1({{77, 88}}); + auto expected = Literal::CreateR1({{77, 88}}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateR2C64) { + Literal output; + output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + auto expected = + Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue(2.5f, {}); @@ -654,6 +712,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { + Literal output; + output.PopulateWithValue({4, 2}, {2, 2}); + auto expected = + Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output; half h(0.25f); @@ -919,6 +985,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, }}, layout_r4_dim0major_); + auto c64 = Literal::CreateR4WithLayout({{ + {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, + {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, + {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, + }}, layout_r4_dim0major_); // clang-format on std::unique_ptr conv; @@ -961,12 +1032,22 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = u32->Convert(F16).ConsumeValueOrDie(); EXPECT_EQ(*conv, *f16); + conv = s32->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + + conv = f16->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(S16).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(U16).status().code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(F32).status().code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(S32).status().code(), + tensorflow::error::INVALID_ARGUMENT); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e4e37177a2d74e6da20300f1439942a146ad8d49..2113b5e06f3eb0169be50c0ee731a903c0eece9d 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -83,10 +83,17 @@ PrimitiveType NativeToPrimitiveType() { return F16; } +template <> +PrimitiveType NativeToPrimitiveType() { + return C64; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } +bool IsComplexType(PrimitiveType type) { return type == C64; } + bool IsSignedIntegralType(PrimitiveType type) { return type == S8 || type == S16 || type == S32 || type == S64; } @@ -121,6 +128,7 @@ int BitWidth(PrimitiveType type) { case U64: case S64: case F64: + case C64: return 64; case TUPLE: @@ -134,5 +142,15 @@ int BitWidth(PrimitiveType type) { } } +PrimitiveType ComplexComponentType(PrimitiveType complex_type) { + switch (complex_type) { + case C64: + return F32; + default: + LOG(FATAL) << "Primitive type is not complex: " + << PrimitiveType_Name(complex_type); + } +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 162a11c7d2966346979b98c804917203f82c806c..a49c8b86fcfe156ea3733ce05c0fb7337cf60dce 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -78,8 +78,14 @@ PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +// Complex +template <> +PrimitiveType NativeToPrimitiveType(); + bool IsFloatingPointType(PrimitiveType type); +bool IsComplexType(PrimitiveType type); + bool IsSignedIntegralType(PrimitiveType type); bool IsUnsignedIntegralType(PrimitiveType type); @@ -89,6 +95,10 @@ bool IsIntegralType(PrimitiveType type); // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); +// Returns the real, imag component type underlying the given complex type. +// LOG(FATAL)'s if complex_type is not complex. +PrimitiveType ComplexComponentType(PrimitiveType complex_type); + // Returns the native type (eg, float) corresponding to the given template // parameter XLA primitive type (eg, F32). template @@ -157,6 +167,11 @@ struct PrimitiveTypeToNative { using type = half; }; +// Complex +template <> +struct PrimitiveTypeToNative { + using type = complex64; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index cdc4139cd69c3d6eb4afc2e5d25f9446ffad0a11..787725e884c810fd724ab88ad7d4beaf3e0a6cc7 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -37,34 +37,27 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } -StatusOr ToJson(const tensorflow::protobuf::Message& message) { - string json_output; - tensorflow::protobuf::util::JsonPrintOptions json_options; - json_options.add_whitespace = true; - json_options.always_print_primitive_fields = true; - auto status = tensorflow::protobuf::util::MessageToJsonString( - message, &json_output, json_options); - if (!status.ok()) { - return InternalError("MessageToJsonString failed: %s", - status.error_message().data()); - } - return json_output; -} - -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name) { - TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message)); +namespace { - tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); - string safe_file_name = file_name + ".json"; +string SanitizeFilename(const string& file_name) { + string safe_file_name = file_name; for (char& c : safe_file_name) { if (c == '/' || c == '\\') { c = '_'; } } + return safe_file_name; +} + +} // namespace + +Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name) { + tensorflow::Env* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string safe_file_name = SanitizeFileName(file_name) + ".pb"; const string path = tensorflow::io::JoinPath(directory, safe_file_name); - return tensorflow::WriteStringToFile(env, path, json_output); + return tensorflow::WriteBinaryProto(env, path, message); } } // namespace protobuf_util diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 1a895c3585902e8fbc0d20475c2817ef4caa4c71..3667621367c7639c40ff17aee7b77305d4d34e33 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -32,15 +32,12 @@ namespace protobuf_util { extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, const tensorflow::protobuf::Message& m2); -// Returns 'message' as a JSON string. -StatusOr ToJson(const tensorflow::protobuf::Message& message); - -// Converts 'message' to JSON, and dumps it to the path formed by joining -// 'directory/file_name.json'. The 'directory' is recursively created if it -// doesn't already exist, and the 'file_name' is sanitized by replacing illegal -// characters with underscore '_'. -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name); +// Writes the given message in binary proto to the path formed by joining +// 'directory/file_name.pb'. The 'directory' is recursively created if it +// doesn't already exist, and the 'file_name' is sanitized by replacing +// illegal characters with underscore '_'. +Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name); } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 35b5e8cd52ab0ec21a4bd2df3e9fa0538ae60816..eb6a71242ffa1499876b90f14f8a60ffdbdd069c 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -322,8 +322,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { // Set the convolution dimension numbers. ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(2); - dimension_numbers.set_feature_dimension(0); + dimension_numbers.set_input_batch_dimension(2); + dimension_numbers.set_input_feature_dimension(0); + dimension_numbers.set_output_batch_dimension(2); + dimension_numbers.set_output_feature_dimension(0); dimension_numbers.add_spatial_dimensions(1); dimension_numbers.add_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); @@ -374,8 +376,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { // Set the convolution dimension numbers. ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(2); - dimension_numbers.set_feature_dimension(0); + dimension_numbers.set_input_batch_dimension(2); + dimension_numbers.set_input_feature_dimension(0); + dimension_numbers.set_output_batch_dimension(2); + dimension_numbers.set_output_feature_dimension(0); dimension_numbers.add_spatial_dimensions(1); dimension_numbers.add_spatial_dimensions(3); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ee37cc8ad2dab6f46f559f9979e325b652b1260e..1923aaa490284e7a9d90b5d134316a01f67e63ee 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -115,7 +115,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -579,12 +579,14 @@ cc_library( ":shaped_buffer", ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", ], @@ -717,6 +719,18 @@ cc_library( ], ) +tf_cc_test( + name = "name_uniquer_test", + srcs = ["name_uniquer_test.cc"], + deps = [ + ":name_uniquer", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "liveness_util", srcs = ["liveness_util.cc"], @@ -1051,13 +1065,40 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", ], ) +cc_library( + name = "defuser", + srcs = ["defuser.cc"], + hdrs = ["defuser.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "defuser_test", + srcs = ["defuser_test.cc"], + deps = [ + ":defuser", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -1116,7 +1157,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -2050,6 +2091,29 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_runner", + srcs = ["hlo_runner.cc"], + hdrs = ["hlo_runner.h"], + deps = [ + ":executable", + ":hlo", + ":transfer_manager", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4858f47c59448c78a208c70f5b71956beca375b1..2a610e91f05e2e8968b3cec89721c7f9bd6ff64c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -98,11 +98,11 @@ bool ReshapeIsBitcast( HloComputation* CreateScalarBinaryComputation(HloModule* module, PrimitiveType primitive_type, HloOpcode opcode) { - HloComputation::Builder b("scalar computation"); + HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); @@ -141,6 +141,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override; + Status HandleReal(HloInstruction* real, HloInstruction* operand) override; + Status HandleImag(HloInstruction* imag, HloInstruction* operand) override; + Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; @@ -201,17 +204,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { static bool Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification); + bool enable_dot_simplification, bool enable_conv_simplification); private: explicit AlgebraicSimplifierVisitor( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification) + bool enable_dot_simplification, bool enable_conv_simplification) : computation_(computation), is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification) {} + enable_dot_simplification_(enable_dot_simplification), + enable_conv_simplification_(enable_conv_simplification) {} // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -287,15 +291,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot simplication on platforms where it causes a slowdown. bool enable_dot_simplification_; + + // Disable convolution simplication on platforms where it causes a slowdown. + bool enable_conv_simplification_; }; bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification) { - AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive, - std::move(valid_bitcast_callback), - enable_dot_simplification); + bool enable_dot_simplification, bool enable_conv_simplification) { + AlgebraicSimplifierVisitor visitor( + computation, is_layout_sensitive, std::move(valid_bitcast_callback), + enable_dot_simplification, enable_conv_simplification); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -519,11 +526,16 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, // A/pow(B,C) => A*pow(B,-C) if (rhs->opcode() == HloOpcode::kPower) { VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); + // The output shape of the created negate operator should be the same as the + // input. + const Shape& negate_shape = rhs->operand(1)->shape(); HloInstruction* negate = computation_->AddInstruction(HloInstruction::CreateUnary( - divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(1))); + negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1))); + // And the power operator should retain the output shape of the old one. + const Shape& new_power_shape = rhs->shape(); HloInstruction* new_power = computation_->AddInstruction( - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kPower, + HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower, rhs->mutable_operand(0), negate)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary( @@ -912,9 +924,10 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A Broadcast that feeds a unary element-wise operation can sink the // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( - changed_, + bool sink_succeeded, TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); - if (changed_) { + changed_ |= sink_succeeded; + if (sink_succeeded) { return Status::OK(); } @@ -957,6 +970,24 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { return Status::OK(); } +// Real(Complex(r, i)) -> r +Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real, + HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kComplex) { + return ReplaceInstruction(real, operand->mutable_operand(0)); + } + return Status::OK(); +} + +// Imag(Complex(r, i)) -> i +Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag, + HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kComplex) { + return ReplaceInstruction(imag, operand->mutable_operand(1)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice. @@ -1217,9 +1248,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // A Reshape that feeds a unary element-wise operation can sink the // reshape after the unary element-wise operation. TF_ASSIGN_OR_RETURN( - changed_, + bool sink_succeeded, TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - if (changed_) { + changed_ |= sink_succeeded; + if (sink_succeeded) { return Status::OK(); } @@ -1262,6 +1294,11 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); } + // DynamicSlice where operand has the same size as the output and + // start_indices are all zero is simply equal to operand. + if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) { + return ReplaceInstruction(dynamic_slice, operand); + } return Status::OK(); } @@ -1280,8 +1317,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( // not to affect the visible behavior of this op even when the indices are out // of range. Currently dynamic-update-slice wraps out-of-range indices, so // we can only remove the op if its indices never wrap.) - if (start_indices->IsConstant() && start_indices->literal().IsAll(0) && - ShapeUtil::Compatible(dynamic_update_slice->shape(), update->shape())) { + if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } return Status::OK(); @@ -1453,6 +1489,9 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) { + if (!enable_conv_simplification_) { + return Status::OK(); + } // HandleConvolution tries to replace a convolution with a DOT instruction. // // Only add when bitcasts can be used: @@ -1505,7 +1544,10 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // still convert Conv into more efficient Matmul with operand transposition // (such as the transposition flags in cuBLAS SGEMM). if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || - input_shape.layout().minor_to_major(0) != dnums.feature_dimension() || + input_shape.layout().minor_to_major(0) != + dnums.input_feature_dimension() || + convolution_shape.layout().minor_to_major(0) != + dnums.output_feature_dimension() || // The input feature dimension should come later in the minor-to-major // order. (PositionInContainer(filter_shape.layout().minor_to_major(), @@ -1524,14 +1566,14 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Replace it with a dot, with bitcasts around it to get the right shape. const int64 input_channels = - input_shape.dimensions(dnums.feature_dimension()); + input_shape.dimensions(dnums.input_feature_dimension()); const int64 output_channels = filter_shape.dimensions(dnums.kernel_output_feature_dimension()); // Computes the product of the non-feature dimensions. int64 conv_width = 1; for (int i = 0; i < input_shape.dimensions_size(); ++i) { - if (i != dnums.feature_dimension()) { + if (i != dnums.input_feature_dimension()) { conv_width *= input_shape.dimensions(i); } } @@ -1782,7 +1824,7 @@ static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { // Tries to determine the number of times the given loop executes. Currently // simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(const HloInstruction* while_op) { +static optional GetLoopTripCount(HloInstruction* while_op) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); VLOG(2) << "Getting trip count for loop " << while_op->ToString(); @@ -1803,15 +1845,10 @@ static optional GetLoopTripCount(const HloInstruction* while_op) { // compute how many times the loop executes. Start by computing the induction // variable's initial value. HloEvaluator evaluator; - auto* while_init = while_op->operand(0); - auto* indvar_init = while_init->operand(*indvar_tuple_idx); - // TODO(b/67157142): This should not be redundant, remove this when the - // underlying issue has been addressed. - if (!hlo_query::AllOperandsAreConstants(*indvar_init)) { - return nullopt; - } + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init->Clone().get()); + evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init: " << indvar_init_result.status(); @@ -1925,7 +1962,7 @@ Status AlgebraicSimplifierVisitor::HandleWhile(HloInstruction* while_op) { return Status::OK(); } - // Remove while loops with static trip count of 1. + // Remove while loops with static trip count of 0. optional trip_count = GetLoopTripCount(while_op); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its @@ -1940,8 +1977,10 @@ Status AlgebraicSimplifierVisitor::HandleWhile(HloInstruction* while_op) { changed_ = true; return Status::OK(); } + + // Transform while loops with static trip count of 1 into a call op, then + // inline the call. if (trip_count && *trip_count == 1) { - // Transform the while loop into a call op, then inline the call. auto computation = while_op->parent(); auto call_op = computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), while_op->operands(), while_op->while_body())); @@ -1958,9 +1997,9 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run(comp, is_layout_sensitive_, - valid_bitcast_callback_, - enable_dot_simplification_)) { + if (AlgebraicSimplifierVisitor::Run( + comp, is_layout_sensitive_, valid_bitcast_callback_, + enable_dot_simplification_, enable_conv_simplification_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 4295a3227a837ffc8483b3be59994c9e6ac96aec..a9f476178c7af74c275a10de7727ea64e17d590f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -40,11 +40,13 @@ class AlgebraicSimplifier : public HloPassInterface { // bitcasts. AlgebraicSimplifier(bool is_layout_sensitive, ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification = true) + bool enable_dot_simplification = true, + bool enable_conv_simplification = true) : is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification) {} - ~AlgebraicSimplifier() override {} + enable_dot_simplification_(enable_dot_simplification), + enable_conv_simplification_(enable_conv_simplification) {} + ~AlgebraicSimplifier() override = default; tensorflow::StringPiece name() const override { return "algsimp"; } // Run algebraic simplification on the given computation. Returns whether the @@ -57,6 +59,9 @@ class AlgebraicSimplifier : public HloPassInterface { // Enable dot simplication on platforms where it is profitable. bool enable_dot_simplification_; + + // Enable convolution simplication on platforms where it is profitable. + bool enable_conv_simplification_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index cf97a261da957353de852427b3a7394e5f511d13..87d4fc9663daf3cc2806dfa6550812dd9b08b36c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -47,7 +47,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloTestBase { +class AlgebraicSimplifierTest : public HloVerifiedTestBase { public: // Makes a computation that contains a loop that runs num_iters times. HloComputation* MakeSimpleLoop(HloModule* module, int num_iters); @@ -353,6 +353,42 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { op::Multiply(param0, op::Power(param1, op::Negate(param2)))); } +// Test that broadcasting is done on the right step when simplifying A/pow(B,C) +// to A*pow(B,-C). +TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r1f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* power = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2)); + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Power(param1, param2))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + ASSERT_THAT(computation->root_instruction(), + op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + + const HloInstruction* negate = + computation->root_instruction()->operand(1)->operand(1); + const Shape& negate_shape = negate->shape(); + EXPECT_EQ(0, negate_shape.dimensions_size()); +} + // Test that A/1 is simplified to A for a scalar. TEST_F(AlgebraicSimplifierTest, DivOneScalar) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -397,6 +433,56 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { EXPECT_EQ(root, param0); } +// Test that real(complex(r,i)) is simplified to r. +TEST_F(AlgebraicSimplifierTest, RealOfComplex) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r2f32, "param1")); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64), + HloOpcode::kComplex, param0, param1)); + HloInstruction* real = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, real); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that imag(complex(r,i)) is simplified to i. +TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r2f32, "param1")); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64), + HloOpcode::kComplex, param0, param1)); + HloInstruction* imag = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, imag); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param1); +} + // Test that get_element(make_tuple({A,B}),1) is simplified to B TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1077,6 +1163,54 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { op::Maximum(op::Reshape(param), zero)); } +// Regression test for a bug where if we failed to sink a reshape, we'd set the +// 'changed' bit in AlgebraicSimplifier to false. +TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { + HloComputation::Builder builder(TestName()); + + // This add (param0 + 0) can be simplified. + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")), + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0, 0}, {0, 0}}))))); + + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); +} + +// Regression test for a bug where if we failed to sink a reshape, we'd set the +// 'changed' bit in AlgebraicSimplifier to false. +TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { + HloComputation::Builder builder(TestName()); + + // This add (param0 + 0) can be simplified. + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")), + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0, 0}, {0, 0}}))))); + + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0})); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -1530,7 +1664,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { for (int i = 0; i < strlen(options.dim_order); ++i) { char ch = options.dim_order[i]; if (ch == 'N') { - dnums.set_batch_dimension(i); + dnums.set_input_batch_dimension(i); + dnums.set_output_batch_dimension(i); in_dims.push_back(options.in_batch); } else if (ch == 'H') { dnums.set_spatial_dimensions(0, i); @@ -1539,7 +1674,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { dnums.set_spatial_dimensions(1, i); in_dims.push_back(options.in_width); } else if (ch == 'C') { - dnums.set_feature_dimension(i); + dnums.set_input_feature_dimension(i); + dnums.set_output_feature_dimension(i); in_dims.push_back(options.in_channels); in_channel_idx = i; } @@ -2041,7 +2177,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); - builder.AddInstruction( + call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); auto module = CreateNewModule(); @@ -2165,6 +2301,29 @@ TEST_F(AlgebraicSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); } +// A dynamic-slice is trivial if its start indices are all zeroes and the size +// of its input equals the size of its output. In this case, the dynamic slice +// is equal to its input. +TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { + HloComputation::Builder builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + shape, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "slice_from")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), + /*slice_sizes=*/{10, 100, 1000})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Parameter()); +} + // A dynamic-update-slice is trivial if its start indices are all zeroes and the // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update. diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc index 427294dfc6fa4a27e28dc0fcb0f726601aa94468..abe881cd1a58a6173b9b93f10a7308d70106c889 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -83,11 +83,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, HloOpcode opcode) { - HloComputation::Builder b("scalar computation"); + HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 688aff89125ce3e30be8918a9dfe9f17e22e6243..08a53af8baa3f250919517c87c023c329b129024 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -320,6 +320,13 @@ class BufferAssignment { const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const; + // Returns true if the top-level buffers of hlo_a and hlo_b are the same. + // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b). + bool SharesTopLevelSlice(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const { + return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); + } + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index e3378a756b383a17a937f55afcd9ac08fe175fec..89410f42bd7b5fa8f9b380c868fcd4fedb54576c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1179,7 +1179,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); - // Buffers for call are co-located with the sub-computation. + // Buffers for call are colocated with the sub-computation. EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}), GetAllocation(*assignment, sub_tuple, /*index=*/{})); EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}), @@ -1238,7 +1238,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { auto assignment = RunBufferAssignment(module.get()); - // Buffers for call are co-located with the sub-computations. + // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), GetAllocation(*assignment, b_call, /*index=*/{})); EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}), diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d5bd9214be44f4abd5f672168335ae1a259c9118..4c2d9600d909e82dcb62f508a10445c08c1cdee6 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -114,7 +114,8 @@ class Compiler { // sequence of executable objects. virtual StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) = 0; + std::vector> + stream_exec) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index fa6e5b231376933aab3381c922651f6091ec023a..ef8eed3f88c3d557fcb4ec5b9e1988ce82b777e8 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -48,6 +48,29 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) +cc_library( + name = "external_constant_pool", + srcs = ["external_constant_pool.cc"], + hdrs = ["external_constant_pool.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "external_constant_pool_test", + srcs = ["external_constant_pool_test.cc"], + deps = [ + ":external_constant_pool", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], @@ -64,6 +87,7 @@ cc_library( ":ir_emitter", ":layout_assignment", ":parallel_cpu_executable", + ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -122,15 +146,17 @@ cc_library( name = "simple_orc_jit", srcs = ["simple_orc_jit.cc"], hdrs = ["simple_orc_jit.h"], - linkopts = ["-ldl"], deps = [ ":compiler_functor", ":cpu_runtime", ":cpu_runtime_avx", ":cpu_runtime_neon", ":cpu_runtime_sse4_1", + ":custom_call_target_registry", ":disassembler", + ":external_constant_pool", ":runtime_conv2d", + ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", @@ -217,7 +243,9 @@ cc_library( ":cpu_options", ":cpu_runtime", ":dot_op_emitter", + ":external_constant_pool", ":ir_emission_utils", + ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -238,6 +266,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -479,9 +508,24 @@ cc_library( ], ) +cc_library( + name = "runtime_fork_join", + srcs = ["runtime_fork_join.cc"], + hdrs = ["runtime_fork_join.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], + tags = ["optonly"], deps = [ ":cpu_runtime", ":runtime_matmul", @@ -662,6 +706,7 @@ cc_library( ":shape_partition", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/compiler/xla/service:hlo_pass", ], ) @@ -674,6 +719,17 @@ cc_library( ], ) +cc_library( + name = "custom_call_target_registry", + srcs = [ + "custom_call_target_registry.cc", + ], + hdrs = [ + "custom_call_target_registry.h", + ], + visibility = ["//visibility:public"], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 069979c6611e90ed2d95cbbe341198577cdf56cf..44cd2171afdc6eecc22f3f920276a4d95f930573 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -36,8 +36,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { !PotentiallyImplementedAsEigenConvolution(*hlo)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); - auto batch_dim = dnums.batch_dimension(); - auto feature_dim = dnums.feature_dimension(); + auto input_batch_dim = dnums.input_batch_dimension(); + auto input_feature_dim = dnums.input_feature_dimension(); auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension(); auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension(); @@ -59,15 +59,16 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_input_dim_order(num_dims); std::vector new_input_dims(num_dims); - new_input_dim_order[0] = batch_dim; - new_input_dims[0] = input->shape().dimensions(batch_dim); + new_input_dim_order[0] = input_batch_dim; + new_input_dims[0] = input->shape().dimensions(input_batch_dim); for (int i = 0; i < num_spatial_dims; ++i) { new_input_dim_order[i + 1] = dnums.spatial_dimensions(i); new_input_dims[i + 1] = input->shape().dimensions(dnums.spatial_dimensions(i)); } - new_input_dim_order[num_dims - 1] = feature_dim; - new_input_dims[num_dims - 1] = input->shape().dimensions(feature_dim); + new_input_dim_order[num_dims - 1] = input_feature_dim; + new_input_dims[num_dims - 1] = + input->shape().dimensions(input_feature_dim); Shape new_input_shape = ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims); @@ -98,22 +99,26 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { new_kernel_dim_order)); std::vector new_conv_dims(num_dims); - new_conv_dims[0] = hlo->shape().dimensions(batch_dim); + auto output_batch_dim = dnums.output_batch_dimension(); + auto output_feature_dim = dnums.output_feature_dimension(); + new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim); for (int i = 0; i < num_spatial_dims; ++i) { new_conv_dims[i + 1] = hlo->shape().dimensions(dnums.spatial_dimensions(i)); } - new_conv_dims[num_dims - 1] = hlo->shape().dimensions(feature_dim); + new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim); Shape new_conv_shape = ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); ConvolutionDimensionNumbers new_dnums; - new_dnums.set_batch_dimension(0); + new_dnums.set_input_batch_dimension(0); + new_dnums.set_output_batch_dimension(0); for (int i = 0; i < num_spatial_dims; ++i) { new_dnums.add_spatial_dimensions(i + 1); new_dnums.add_kernel_spatial_dimensions(i); } - new_dnums.set_feature_dimension(num_dims - 1); + new_dnums.set_input_feature_dimension(num_dims - 1); + new_dnums.set_output_feature_dimension(num_dims - 1); new_dnums.set_kernel_input_feature_dimension(num_dims - 2); new_dnums.set_kernel_output_feature_dimension(num_dims - 1); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 9e8b785f30559f493bcec546e0612f2290af031d..d593ba26b655d00a0f0f0b9a94c9e62fa1835080 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -67,10 +67,12 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(1); + dnums.set_input_batch_dimension(1); + dnums.set_output_batch_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); - dnums.set_feature_dimension(0); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); dnums.add_kernel_spatial_dimensions(2); dnums.add_kernel_spatial_dimensions(3); dnums.set_kernel_input_feature_dimension(1); @@ -121,10 +123,12 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 2ad357896960b2a52436031b593dd08269c593d1..65e117e68f5b82345fc7d4317fd8b753e47670df 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -58,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -248,7 +249,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); @@ -269,6 +270,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { { auto& pass = pipeline.AddPass>("simplification"); + pass.AddInvariantChecker(ShapeSizeBytesFunction()); + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, @@ -279,6 +282,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { [](const Shape&, const Shape&) { return false; }, /*enable_dot_simplification=*/false); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); } @@ -314,6 +318,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { if (options::CpuParallelBackendRequested(module->config())) { pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); + } else if (!is_aot_compile) { + // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. + // Note this is not run for AOT because it would bring in thread pool + // and thread synchronization dependencies which would likely increase + // binary size (and most AOT applications are single-threaded). + // TODO(29630486) Support multi-threaded AOT. + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction(), module); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -448,7 +460,7 @@ StatusOr> CpuCompiler::Compile( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get())); + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; @@ -464,8 +476,8 @@ StatusOr> CpuCompiler::Compile( // ownership is std::moved. const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); if (options::CpuParallelBackendRequested(module->config())) { VLOG(1) << "Using parallel cpu backend"; @@ -485,10 +497,10 @@ StatusOr> CpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!dump_debug_json_to.empty()) { + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -522,7 +534,8 @@ StatusOr> CpuCompiler::Compile( } IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine()); + &hlo_to_profile_idx, jit->target_machine(), + jit->external_constant_pool()); std::unique_ptr> function_names( new std::map()); @@ -591,18 +604,18 @@ StatusOr> CpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!dump_debug_json_to.empty()) { + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } - // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine()); + &hlo_to_profile_idx, jit->target_machine(), + jit->external_constant_pool()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -649,7 +662,7 @@ StatusOr> CpuCompiler::Compile( StatusOr>> CpuCompiler::Compile( std::vector> modules, - std::vector stream_execs) { + std::vector> stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on CPU."); } @@ -745,7 +758,13 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, HloModule* module = modules[i].get(); VLOG(1) << "Compiling ahead-of-time: " << module->name(); - TF_RETURN_IF_ERROR(RunHloPasses(module)); + VLOG(2) << "Before optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + + TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); + + VLOG(2) << "After optimization:"; + XLA_VLOG_LINES(2, module->ToString()); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, @@ -762,16 +781,17 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); - if (!dump_debug_json_to.empty()) { + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr, target_machine.get()); + /*hlo_to_profile_idx=*/nullptr, target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index bd3541500dae9d9d59c56bfb062912a1b85c2219..d09130247421b11d6d4879466f39b89167eb9564 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -115,7 +115,8 @@ class CpuCompiler : public LLVMCompiler { StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) override; + std::vector> + stream_execs) override; StatusOr>> CompileAheadOfTime(std::vector> modules, @@ -131,7 +132,7 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module); + Status RunHloPasses(HloModule* module, bool is_aot_compile); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index 2cd0aa788057d585c2a60bd03f596b129cc53554..662ee609232f5582ce74f4f515637b2623175e94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -116,26 +116,6 @@ StatusOr ParallelizationPreparation::RunParallelTaskAssignment( // Assign parallel tasks to HLOs in entry computation. HloComputation* computation = module->entry_computation(); for (auto* instruction : computation->instructions()) { - // Currently, we do not assign parallel tasks to instructions with at least - // one of the following properties: - // *) Internal threading (library calls to kConv, kDot, and kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). - // *) Tuple-shaped. - // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || - ShapeUtil::IsTuple(instruction->shape())) { - continue; - } - // Calculate target parallel task count in [1, max_parallelism_]. const int64 target_parallel_task_count = parallel_task_assignment.GetTargetParallelTaskCount(instruction); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index c7155b858bda5e5640e9a6719fb394ca1360d128..7908dc173d79a4a9dcb6127ac344267e27d2b5f2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -51,6 +51,9 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName = "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation"; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; +extern const char* const kParallelForkJoinSymbolName = + "__xla_cpu_runtime_ParallelForkJoin"; + extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 29feb7267fe97f6876827b6cbfa6217a0cecf238..2ade455b8a0a43dda8c93bbb79891439da2e4f75 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -51,6 +51,7 @@ extern const char* const kAcquireInfeedBufferForDequeueSymbolName; extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; +extern const char* const kParallelForkJoinSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f5803874b7886e56da47250d0dbe297f5db16c5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" + +namespace xla { +namespace cpu { + +CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { + static auto* registry = new CustomCallTargetRegistry; + return registry; +} + +void CustomCallTargetRegistry::Register(const std::string& symbol, + void* address) { + std::lock_guard lock(mu_); + registered_symbols_[symbol] = address; +} + +void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const { + std::lock_guard lock(mu_); + auto it = registered_symbols_.find(symbol); + return it == registered_symbols_.end() ? nullptr : it->second; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..2994642356d55df26c31553ef28dc653503d05be --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ + +// This file is depended on by kernels that have to build for mobile devices. +// For this reason, we avoid relying on TensorFlow and instead only use the +// standard C++ library. + +#include // NOLINT +#include +#include + +namespace xla { +namespace cpu { + +// The CPU JIT compiler uses this registry to resolve symbolic CustomCall +// targets; so when using the CPU JIT, CustomCall targets need to be registered +// here with the symbol name used in the CustomCall. +// +// The XLA AOT compiler links using a standard offline linker; so when compiling +// in AOT mode, you *also* need to make sure the name of the callee (presumably +// implemented in C++) matches up with the symbolic name used in the CustomCall. +// +// We maintain the registry in both the JIT and the AOT cases for simplicity, +// but we only use it when running in JIT mode. +class CustomCallTargetRegistry { + public: + static CustomCallTargetRegistry* Global(); + + void Register(const std::string& symbol, void* address); + void* Lookup(const std::string& symbol) const; + + private: + std::unordered_map registered_symbols_; + mutable std::mutex mu_; +}; + +class RegisterCustomCallTarget { + public: + explicit RegisterCustomCallTarget(const std::string& name, void* address) { + CustomCallTargetRegistry::Global()->Register(name, address); + } +}; + +#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b + +#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \ + static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \ + custom_call_target_register, counter)(symbol, \ + reinterpret_cast(address)) + +#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ + REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__) + +#define REGISTER_CUSTOM_CALL_TARGET(function) \ + REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index d3b94d75411218346cd25b0d3ecc3a9f30b56ba3..e57d49172b18beb75cfbb482c5d732ef679ebe41 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -63,7 +63,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config) { PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F32 == type || F64 == type); + TF_RET_CHECK(F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, lhs_array, rhs_array, executable_run_options_value, ir_builder, hlo_module_config); @@ -176,7 +176,7 @@ tensorflow::Status DotOpEmitter::Emit() { llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock(); ir_builder_->SetInsertPoint(preheader_bb->getTerminator()); - ir_builder_->CreateStore(llvm::ConstantFP::get(accum_type, 0.0), + ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address); // Body basic block of reduction loop: @@ -191,9 +191,29 @@ tensorflow::Status DotOpEmitter::Emit() { llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_); - llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element); llvm::Value* accum = ir_builder_->CreateLoad(accum_address); - llvm::Value* updated_accum = ir_builder_->CreateFAdd(accum, product); + llvm::Value* updated_accum; + if (ShapeUtil::ElementIsComplex(lhs_shape)) { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)), + ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)), + ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element))); + updated_accum = ir_builder_->CreateInsertValue( + accum, ir_builder_->CreateFAdd(real(accum), product_real), {0}); + updated_accum = ir_builder_->CreateInsertValue( + updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1}); + } else { + llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element); + updated_accum = ir_builder_->CreateFAdd(accum, product); + } ir_builder_->CreateStore(updated_accum, accum_address); // Exit basic block of reduction loop. @@ -230,11 +250,28 @@ tensorflow::Status DotOpEmitter::Emit() { tensorflow::Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. + llvm::Value* result; llvm::Value* lhs_value = lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); llvm::Value* rhs_value = rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); - llvm::Value* result = ir_builder_->CreateFMul(lhs_value, rhs_value); + if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { +#define REAL(x) ir_builder_->CreateExtractValue(x, {0}) +#define IMAG(x) ir_builder_->CreateExtractValue(x, {1}) + llvm::Value* real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), + ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); + llvm::Value* imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), + ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); +#undef IMAG +#undef REAL + result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType()); + result = ir_builder_->CreateInsertValue(result, real, {0}); + result = ir_builder_->CreateInsertValue(result, imag, {1}); + } else { + result = ir_builder_->CreateFMul(lhs_value, rhs_value); + } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index 73e039250ba62b1313c98965421f6d823ca6a3b0..ba693ec89ab7c4090f8c9d1e4d65f17a80d0ac55 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -46,8 +46,8 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( } // Create function type for the function. llvm::FunctionType* function_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_), - llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), /*isVarArg=*/false); // Create function declaration for 'tanhf'. llvm::Function* function = diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9f8e5584965d0c73771750e26bd63c401d5b0c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { +namespace cpu { +void ExternalConstantPool::Insert(string name, const Literal& literal, + int64 alignment) { + CHECK(!ShapeUtil::IsTuple(literal.shape())); + CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); + CHECK(entries_.find(name) == entries_.end()); + + int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + void* raw_pointer; + CHECK_EQ( + posix_memalign(&raw_pointer, std::max(alignment, sizeof(void*)), + literal_size), + 0) + << "failed to allocate " << literal_size << " bytes with alignment of " + << alignment; + + std::memcpy(raw_pointer, literal.InternalData(), literal_size); + entries_.emplace(std::move(name), static_cast(raw_pointer)); +} + +const uint8* ExternalConstantPool::Find(const string& name) { + auto it = entries_.find(name); + return it == entries_.end() ? nullptr : it->second.get(); +} +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..ade28cbcbcfda05a9ad0adab1139bf316720e11f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { +namespace cpu { +// An ExternalConstantPool maintains a set of constants kept external to +// generated LLVM IR. These constants are accessed from the IR via globals with +// extern linkage. This current incarnation of ExternalConstantPool only +// supports the JIT CPU backend; the AOT backend is not supported. +// +// Implementation-wise, this is a simple wrapper around a map of strings to byte +// buffers. This simply implementation works in a JIT scenario. This class +// will have to become smarter if we decide to support external constant pools +// on AOT compiles in the future. +class ExternalConstantPool { + public: + // Inserts a buffer with the contents of `literal` into the constant pool with + // the name `name`. It is an error to try to insert two constants with the + // same `name` into the same constant pool. The buffer for literal is aligned + // to `aligment` bytes, and `alignment` must be a power of 2. + // + // The constant pool copies out the contents of `literal` into a buffer it + // owns -- it does not keep pointers to `literal`, or to memory owned by + // `literal`. + void Insert(string name, const Literal& literal, int64 alignment); + + // Find the constant with name `name` in this constant pool. If there isn't + // such constant, return nullptr. + const uint8* Find(const string& name); + + private: + // We need to `free()` pointers allocated into `entries_` since we allocate + // them with `posix_memalign`. + struct FreeDeleter { + void operator()(void* ptr) { free(ptr); } + }; + + tensorflow::gtl::FlatMap> + entries_; +}; +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9290a4e5dfc03ddb86e9d82f1f0f4f9a8ceebb88 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class ExternalConstantPoolTest : public ::testing::Test {}; + +template +T GetFromBuffer(const uint8* buffer, int64 index) { + T result; + std::memcpy(&result, buffer + index * sizeof(T), sizeof(T)); + return result; +} + +TEST(ExternalConstantPoolTest, Basic) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); + constant_pool.Insert("name-0", *literal, 4); + const uint8* constant = constant_pool.Find("name-0"); + ASSERT_NE(constant, nullptr); + + EXPECT_EQ(GetFromBuffer(constant, 0), 1); + EXPECT_EQ(GetFromBuffer(constant, 1), 2); + EXPECT_EQ(GetFromBuffer(constant, 2), 3); + EXPECT_EQ(GetFromBuffer(constant, 3), 4); + + EXPECT_EQ(constant_pool.Find("name-1"), nullptr); +} + +TEST(ExternalConstantPoolTest, RowMinorLayout) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + const auto literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); + constant_pool.Insert("name-0", *literal, 4); + const uint8* constant = constant_pool.Find("name-0"); + ASSERT_NE(constant, nullptr); + + EXPECT_EQ(GetFromBuffer(constant, 0), 1); + EXPECT_EQ(GetFromBuffer(constant, 1), 3); + EXPECT_EQ(GetFromBuffer(constant, 2), 2); + EXPECT_EQ(GetFromBuffer(constant, 3), 4); +} + +TEST(ExternalConstantPoolTest, Alignment) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + + for (int i = 0; i < 8; i++) { + int64 alignment = 1 << i; + string name = tensorflow::strings::StrCat("name-", i); + + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); + constant_pool.Insert(name, *literal, alignment); + + const uint8* constant = constant_pool.Find(name); + ASSERT_NE(constant, nullptr); + EXPECT_EQ(reinterpret_cast(constant) % alignment, 0); + } +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 91b09f2472e4001d8df8aa1ce4dc2796af2a32e7..d72abede022cbe771d126273cb6ff8b8e18cbb43 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -41,6 +41,12 @@ bool PotentiallyImplementedAsEigenConvolution( ShapeUtil::HasZeroElements(kernel_shape)) { return false; } + // TODO(b/65408531): Explore using Eigen dot for complex64 type. + if (ShapeUtil::ElementIsComplex(input_shape) || + ShapeUtil::ElementIsComplex(kernel_shape)) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); // Only 1D and 2D convolutions are supported at the moment. @@ -55,8 +61,12 @@ bool PotentiallyImplementedAsEigenConvolution( std::is_sorted(dnums.kernel_spatial_dimensions().begin(), dnums.kernel_spatial_dimensions().end()); - return dnums.batch_dimension() == 0 && - dnums.feature_dimension() == input_shape.dimensions_size() - 1 && + const Shape& output_shape = convolution.shape(); + return dnums.input_batch_dimension() == 0 && + dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && + dnums.output_batch_dimension() == 0 && + dnums.output_feature_dimension() == + output_shape.dimensions_size() - 1 && input_spatial_dims_ascending == kernel_spatial_dims_ascending && dnums.kernel_input_feature_dimension() == kernel_shape.dimensions_size() - 2 && diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 2a952328a72068884c28ef74223d454148c7f48d..fa3b3ab8e72a27beab5a6a4b53beb2f73d4a1ddc 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -49,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -75,7 +77,8 @@ IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx, - llvm::TargetMachine* target_machine) + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -86,7 +89,8 @@ IrEmitter::IrEmitter( parallel_cpu_backend_( options::CpuParallelBackendRequested(hlo_module_config_)), is_top_level_computation_(false), - target_machine_features_(target_machine) { + target_machine_features_(target_machine), + external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); @@ -183,20 +187,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { // Even though the type of params and temps is void** in the host's view, in // LLVM IR this is represented by i8*, similarly to void*. It's up to the code // to use GEPs to unravel the indirection layers. - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (IsParallelContext()) { - compute_function_params.push_back(i64_ptr_type); - } - if (hlo_to_profile_idx_) { - compute_function_params.push_back(i64_ptr_type); - } llvm::FunctionType* compute_function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, + /*Params=*/GetComputeFunctionParams(), /*isVarArg=*/false); // Functions with local linkage get an inlining bonus. Because we know @@ -218,7 +211,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { (++arg_iter)->setName("run_options"); (++arg_iter)->setName("params"); (++arg_iter)->setName("temps"); - if (IsParallelContext()) { + if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); } if (hlo_to_profile_idx_) { @@ -272,15 +265,39 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { Status IrEmitter::HandleConstant(HloInstruction* constant, const Literal& literal) { VLOG(2) << "HandleConstant: " << constant->ToString(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/initializer->getType(), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/initializer, - /*Name=*/""); + llvm::GlobalVariable* global_for_const; + + // We avoid creating large constants in the LLVM IR since LLVM is not + // efficient for large constant arrays. We still emit "small enough" constant + // arrays into the Ir, in the off chance the LLVM optimizer can do something + // interesting with it. + const int kMaxInternalConstantSizeInBytes = 128; + if (external_constant_pool_ && + ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { + string global_name = tensorflow::strings::StrCat( + "constant_global_", external_global_constant_counter_++); + global_for_const = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/IrShapeType(literal.shape()), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/nullptr, + /*Name=*/AsStringRef(global_name)); + global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + external_constant_pool_->Insert(global_name, literal, + MinimumAlignmentForShape(literal.shape())); + } else { + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + global_for_const = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + } emitted_value_[constant] = global_for_const; VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); VLOG(2) << " its type: " @@ -291,8 +308,7 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. - TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); - emitted_value_[copy] = copy_value; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. @@ -385,7 +401,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, const Shape& shape = get_tuple_element->shape(); emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), - GetEmittedValueFor(operand), &ir_builder_); + GetEmittedValueFor(operand), &ir_builder_, module_); return Status::OK(); } @@ -395,12 +411,10 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(select)); - emitted_value_[select] = output_address; - llvm_ir::EmitTupleSelect(GetIrArrayForOp(select), GetIrArrayForOp(pred), - GetEmittedValueFor(on_true), - GetEmittedValueFor(on_false), &ir_builder_); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); + llvm_ir::EmitTupleSelect( + GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), + GetEmittedValueFor(on_false), &ir_builder_, module_); return Status::OK(); } @@ -414,8 +428,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { // The infeed operation produces data (dequeued from the infeed queue) at this // address, which has been provided by buffer assignment. - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(infeed)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); + llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); if (ShapeUtil::IsTuple(shape)) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); @@ -433,9 +447,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { ShapeUtil::GetTupleElementShape(shape, i); // Only the outer tuple buffer's target address is obtained from - // EmitTargetAddressForOp to handle the case when Infeed is the - // root instruction. Target addresses for internal elements can - // be obtained from EmitTempBufferPointer. + // GetEmittedValueFor, to handle the case when Infeed is the root + // instruction. Target addresses for internal elements can be obtained + // from EmitTempBufferPointer. llvm::Value* tuple_element_address = EmitTempBufferPointer(buffer, tuple_element_shape); @@ -445,15 +459,13 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape), - tuple_element_addresses, &ir_builder_); + llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_, + module_); } else { - TF_RETURN_IF_ERROR( - EmitXfeedTransfer(XfeedKind::kInfeed, shape, target_address)); + TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, + GetEmittedValueFor(infeed))); } - emitted_value_[infeed] = target_address; - return Status::OK(); } @@ -551,7 +563,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { ShapeUtil::GetTupleElementShape(operand_shape, i); llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), - value, &ir_builder_); + value, &ir_builder_, module_); TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, tuple_element_shape, tuple_element)); } @@ -567,15 +579,12 @@ Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) { Status IrEmitter::HandleTuple( HloInstruction* tuple, tensorflow::gtl::ArraySlice operands) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(tuple)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple)); std::vector base_ptrs; for (auto operand : operands) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, tuple->shape()), - base_ptrs, &ir_builder_); - emitted_value_[tuple] = target_address; + llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_); return Status::OK(); } @@ -590,7 +599,7 @@ Status IrEmitter::HandleMap( const llvm_ir::IrArray::Index& index) { std::vector parameter_addresses; for (const HloInstruction* operand : operands) { - const llvm_ir::IrArray& array = GetIrArrayForOp(operand); + const llvm_ir::IrArray& array = GetIrArrayFor(operand); parameter_addresses.push_back( array.EmitArrayElementAddress(index, &ir_builder_)); } @@ -636,7 +645,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, // the initial value on the reduce_window. PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accumulator_address", &ir_builder_, MinimumAlignmentForPrimitiveType(operand_element_type)); ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor( @@ -686,7 +695,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, SetToFirstInsertPoint(if_data.true_block, &ir_builder_); // We are not in the padding, so carry out the computation. - llvm_ir::IrArray input_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value_address = input_array.EmitArrayElementAddress(input_index, &ir_builder_); llvm::Value* result = EmitElementFunctionCall( @@ -761,7 +770,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // Allocate space to keep the currently selected value, its index, and // the boolean initialized_flag, which is initially set to false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "selected_value_address", &ir_builder_, MinimumAlignmentForPrimitiveType(operand_element_type)); llvm::Value* selected_index_address = @@ -823,7 +832,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); ir_builder_.CreateStore(operand_data, selected_value_address); @@ -843,8 +852,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. llvm::Value* cond = ir_builder_.CreateICmpNE( - result, llvm::ConstantInt::get( - llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0), + result, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); @@ -866,10 +875,10 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back( ir_builder_.CreateLoad(selected_index_address_slot)); } - llvm_ir::IrArray source_array(GetIrArrayForOp(source)); + llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value_address = source_array.EmitArrayElementAddress(source_index, &ir_builder_); - llvm_ir::IrArray output_array(GetIrArrayForOp(select_and_scatter)); + llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); llvm::Value* output_value_address = output_array.EmitArrayElementAddress(selected_index, &ir_builder_); llvm::Value* scatter_value = EmitElementFunctionCall( @@ -887,16 +896,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32, F64})); + /*supported_types=*/{F32, F64, C64})); - llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - Shape target_shape = dot->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dot)); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*dot, &target_array); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot)); + llvm_ir::IrArray target_array = GetIrArrayFor(dot); VLOG(2) << "HandleDot: "; VLOG(2) << " lhs operand: " @@ -907,13 +913,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + return DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_)); - - emitted_value_[dot] = target_address; - return Status::OK(); + hlo_module_config_); } Status IrEmitter::HandleConvolution(HloInstruction* convolution, @@ -921,7 +924,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, const Window& window) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, C64})); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); @@ -941,21 +944,21 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, bool one_dim_convolution = lhs_shape.dimensions_size() == 3; llvm::Value* lhs_address = GetEmittedValueFor(lhs); llvm::Value* rhs_address = GetEmittedValueFor(rhs); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(convolution)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution)); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); // Input tensor. const Shape& input_shape = convolution->operand(0)->shape(); - int64 input_batch = input_shape.dimensions(dnums.batch_dimension()); + int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension()); int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0)); int64 input_cols = one_dim_convolution ? 1 : input_shape.dimensions(dnums.spatial_dimensions(1)); - int64 input_channels = input_shape.dimensions(dnums.feature_dimension()); + int64 input_channels = + input_shape.dimensions(dnums.input_feature_dimension()); // Kernel tensor. const Shape& kernel_shape = convolution->operand(1)->shape(); @@ -1024,35 +1027,33 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); ir_builder_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast(target_address, float_ptr_type), - ir_builder_.CreateBitCast(lhs_address, float_ptr_type), - ir_builder_.CreateBitCast(rhs_address, float_ptr_type), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(input_rows), - ir_builder_.getInt64(input_cols), - ir_builder_.getInt64(input_channels), - ir_builder_.getInt64(kernel_rows), - ir_builder_.getInt64(kernel_cols), - ir_builder_.getInt64(kernel_channels), - ir_builder_.getInt64(kernel_filters), - ir_builder_.getInt64(output_rows), - ir_builder_.getInt64(output_cols), - ir_builder_.getInt64(row_stride), - ir_builder_.getInt64(col_stride), - ir_builder_.getInt64(padding_top), - ir_builder_.getInt64(padding_bottom), - ir_builder_.getInt64(padding_left), - ir_builder_.getInt64(padding_right), - ir_builder_.getInt64(lhs_row_dilation), - ir_builder_.getInt64(lhs_col_dilation), - ir_builder_.getInt64(rhs_row_dilation), - ir_builder_.getInt64(rhs_col_dilation), - }); - target_address->setName(AsStringRef(IrName(convolution))); - emitted_value_[convolution] = target_address; + conv_func, { + GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast( + GetEmittedValueFor(convolution), float_ptr_type), + ir_builder_.CreateBitCast(lhs_address, float_ptr_type), + ir_builder_.CreateBitCast(rhs_address, float_ptr_type), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(input_rows), + ir_builder_.getInt64(input_cols), + ir_builder_.getInt64(input_channels), + ir_builder_.getInt64(kernel_rows), + ir_builder_.getInt64(kernel_cols), + ir_builder_.getInt64(kernel_channels), + ir_builder_.getInt64(kernel_filters), + ir_builder_.getInt64(output_rows), + ir_builder_.getInt64(output_cols), + ir_builder_.getInt64(row_stride), + ir_builder_.getInt64(col_stride), + ir_builder_.getInt64(padding_top), + ir_builder_.getInt64(padding_bottom), + ir_builder_.getInt64(padding_left), + ir_builder_.getInt64(padding_right), + ir_builder_.getInt64(lhs_row_dilation), + ir_builder_.getInt64(lhs_col_dilation), + ir_builder_.getInt64(rhs_row_dilation), + ir_builder_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1072,14 +1073,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, for (int i = 0; i < num_spatial_dims; ++i) { output_spatial[i] = index[dnums.spatial_dimensions(i)]; } - llvm::Value* output_feature = index[dnums.feature_dimension()]; - llvm::Value* batch = index[dnums.batch_dimension()]; + llvm::Value* output_feature = index[dnums.output_feature_dimension()]; + llvm::Value* batch = index[dnums.output_batch_dimension()]; // We will accumulate the products into this sum to calculate // the output entry at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), "convolution_sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(lhs_element_type)); ir_builder_.CreateStore( @@ -1097,8 +1098,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, } llvm::Value* input_feature = loops - .AddLoop(0, lhs->shape().dimensions(dnums.feature_dimension()), - "iz") + .AddLoop( + 0, lhs->shape().dimensions(dnums.input_feature_dimension()), + "iz") ->GetIndVarValue(); SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); @@ -1178,10 +1180,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, for (int i = 0; i < num_spatial_dims; ++i) { input_index[dnums.spatial_dimensions(i)] = input_spatial[i]; } - input_index[dnums.feature_dimension()] = input_feature; - input_index[dnums.batch_dimension()] = batch; + input_index[dnums.input_feature_dimension()] = input_feature; + input_index[dnums.input_batch_dimension()] = batch; - llvm_ir::IrArray kernel_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); llvm_ir::IrArray::Index kernel_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i]; @@ -1189,7 +1191,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; - llvm_ir::IrArray input_array(GetIrArrayForOp(lhs)); + llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = ir_builder_.CreateFMul( input_array.EmitReadArrayElement(input_index, &ir_builder_), kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_)); @@ -1294,14 +1296,14 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { PrimitiveType element_type = operand->shape().element_type(); // Used to calculate E(X). llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), "sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(element_type)); // Used to calculate E(X^2). llvm::Value* sum_square_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), "sum_square_address", &ir_builder_, MinimumAlignmentForPrimitiveType(element_type)); @@ -1323,7 +1325,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm_ir::IrArray::Index input_index = FillReducedDimensionIndex(reduced_dims_index, index); llvm::Value* new_value = @@ -1367,9 +1369,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { mean_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "mean_var"))); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(batch_norm_training)); - + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training)); TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); @@ -1399,7 +1399,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { llvm::Value* var = var_array.EmitReadArrayElement( feature_index_value, &ir_builder_); - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* input = operand_array.EmitReadArrayElement(index, &ir_builder_); @@ -1411,10 +1411,10 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); llvm::Value* normalized = ir_builder_.CreateFDiv( ir_builder_.CreateFSub(input, mean), variance_sqrt); - llvm_ir::IrArray offset_array(GetIrArrayForOp(offset)); + llvm_ir::IrArray offset_array(GetIrArrayFor(offset)); llvm::Value* offset = offset_array.EmitReadArrayElement( feature_index_value, &ir_builder_); - llvm_ir::IrArray scale_array(GetIrArrayForOp(scale)); + llvm_ir::IrArray scale_array(GetIrArrayFor(scale)); llvm::Value* scale = scale_array.EmitReadArrayElement( feature_index_value, &ir_builder_); llvm::Value* result = ir_builder_.CreateFAdd( @@ -1425,11 +1425,8 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { target_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "normalize"))); - llvm_ir::EmitTuple( - llvm_ir::IrArray(target_address, batch_norm_training->shape()), - {normalized, mean, var}, &ir_builder_); - emitted_value_[batch_norm_training] = target_address; - + llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training), + {normalized, mean, var}, &ir_builder_, module_); return Status::OK(); } @@ -1457,13 +1454,19 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = ir_builder_.CreateLoad(param_address_offset); + param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); + if (hlo_module_config_.debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // We never reassign parameters, so this load is invariant. + param_address_untyped->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); + } + llvm::Value* param_address_typed = ir_builder_.CreateBitCast( param_address_untyped, IrShapeType(param_shape)->getPointerTo()); emitted_value_[parameter] = param_address_typed; - // Parameters of different types may not alias one another. - llvm_ir::SetTbaaForInstruction(param_address_untyped, param_shape, - /*is_pointer_to=*/true); if (!ShapeUtil::IsOpaque(param_shape)) { AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); @@ -1486,6 +1489,14 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( } const Shape& root_shape = root_instruction->shape(); + if (ShapeUtil::ElementIsComplex(root_shape)) { + // TODO(b/65408531): Complex add could by done via bitcast to + // Complex multiply would be more challenging. We could perhaps use a + // strided load to get all reals in a vector, all imags in a vector, or use + // CreateShuffleVector on a bitcast to float x [2N]. + *failure_reason = "complex values not supported"; + return nullptr; + } bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape); bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape); bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape); @@ -1507,7 +1518,7 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( // This is visually similar to ElementalIrEmitter, though conceptually we're // doing something different here. ElementalIrEmitter emits scalar operations // while these emit scalar or vector operations depending on the type of the - // operands. + // operands. See CreateShardedVectorType for the actual types in use here. switch (root_instruction->opcode()) { default: *failure_reason = "did not recognize root instruction opcode"; @@ -1527,11 +1538,11 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( : ir_builder->CreateFMul(lhs, rhs); }; - case HloOpcode::kLogicalAnd: + case HloOpcode::kAnd: return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); }; - case HloOpcode::kLogicalOr: + case HloOpcode::kOr: return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; @@ -1584,7 +1595,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( ShardedVectorType sharded_vector_type; llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_); + llvm_ir::PrimitiveTypeToIrType(element_type, module_); for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) { // For every power of two present in element_count, we generate one or more @@ -1667,7 +1678,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); @@ -1780,6 +1791,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( } CHECK(!ShapeUtil::IsTuple(reduce->shape())); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); // We know we're not reducing over the most minor dimension, which means we // can lower the reduction loop as: @@ -1842,10 +1854,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - llvm_ir::IrArray target_array(target_address, reduce->shape()); - AddAliasingInformationToIrArray(*reduce, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1877,10 +1886,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - llvm_ir::IrArray target_array(target_address, reduce->shape()); - AddAliasingInformationToIrArray(*reduce, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1891,10 +1897,6 @@ StatusOr IrEmitter::EmitVectorizedReduce( ir_builder_.SetInsertPoint(outermost_loop_exit_block); } - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - - emitted_value_[reduce] = target_address; return true; } @@ -1926,7 +1928,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(accumulator_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &ir_builder_, MinimumAlignmentForPrimitiveType(accumulator_type)); llvm::Value* init_value_addr = GetEmittedValueFor(init_value); @@ -1951,7 +1953,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, // filled in. We fill in the rest of the dimensions with induction // Value*s taken from 'index' which iterates over the target array. // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); @@ -1994,9 +1996,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { return DefaultAction(slice); } - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(slice)); - emitted_value_[slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); if (ShapeUtil::HasZeroElements(slice->shape())) { return Status::OK(); @@ -2068,8 +2068,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { outer_dims.push_back(memcpy_dim); } - llvm_ir::IrArray target_array(target_address, slice->shape()); - AddAliasingInformationToIrArray(*slice, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(slice); const int64 num_outer_loops = outer_dims.size(); llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_); @@ -2087,7 +2086,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); } - llvm_ir::IrArray source_array = GetIrArrayForOp(operand); + llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_); @@ -2122,126 +2121,26 @@ Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice, HloInstruction* operand, HloInstruction* /*start_indices*/) { if (ShapeUtil::IsScalar(dynamic_slice->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dynamic_slice)); - target_address->setName(AsStringRef(IrName(dynamic_slice))); - emitted_value_[dynamic_slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice)); return EmitMemcpy(*operand, *dynamic_slice); } return DefaultAction(dynamic_slice); } -namespace { - -// Returns the first non-GetTupleElement ancestor instruction of 'hlo'. -// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the -// (possibly nested) tuple indices used on the path from ancestor to 'hlo'. -const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo, - ShapeIndex* index) { - if (hlo->opcode() == HloOpcode::kGetTupleElement) { - const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index); - index->push_back(hlo->tuple_index()); - return operand; - } - return hlo; -} - -// Checks if we can emit code for DynamicUpdateSlice to update data in-place. -// Returns true if operand 0 of DynamicUpdateSlice and its output buffer -// share the same buffer allocation. -// Returns false otherwise. -// TODO(b/64142684) Share code with GPU implementation. -bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, - HloInstruction* dynamic_update_slice) { - CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); - - // Walk DynamicUpdateSlice operand(0) to parameter and get its - // associated operand. See if it shares an allocation with this operand. - ShapeIndex index; - auto* operand = - LatestNonGteAncestorAndIndex(dynamic_update_slice->operand(0), &index); - if (operand->opcode() != HloOpcode::kParameter) { - return false; - } - - BufferAllocation::Slice operand_slice = - assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie(); - - BufferAllocation::Slice dynamic_update_slice_slice = - assignment.GetUniqueTopLevelSlice(dynamic_update_slice) - .ConsumeValueOrDie(); - - return operand_slice == dynamic_update_slice_slice; -} - -} // namespace - Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dynamic_update_slice)); - target_address->setName(AsStringRef(IrName(dynamic_update_slice))); - emitted_value_[dynamic_update_slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); return EmitMemcpy(*update, *dynamic_update_slice); - } else if (CanUpdateDynamicSliceInPlace(assignment_, dynamic_update_slice)) { - VLOG(2) << "Emitting HandleDynamicUpdateSlice in-place."; - // DynamicUpdateSlice's operand(0) and 'fusion' output share the same - // BufferAllocation::Slice, so it is safe to emit code to update the slice - // 'in-place'. This avoids copying data outside of the slice update region. - // TODO(b/64142684) Implement in-place update for fused DynamicUpdateSlice. - - // Emit IR to read dynamic start indices from 'start_indices'. - const int64 rank = ShapeUtil::Rank(operand->shape()); - llvm_ir::IrArray::Index start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)}); - llvm_ir::IrArray start_indices_array(GetIrArrayForOp(start_indices)); - start_index[i] = - start_indices_array.EmitReadArrayElement(dim_index, &ir_builder_); - } - - // Create loop body emitter which emits code to do the following: - // *) Map requested 'index' and slice 'start_index' to input/output shape - // as 'output_index'. - // *) Reads value from 'update'. - // *) Writes value to input/output array at 'output_index'. - auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& index) -> Status { - // Calculate 'output_index' at which to write value from update. - llvm_ir::IrArray::Index output_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // output_index = (start_index + index) % dim_size - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), operand->shape().dimensions(i)); - llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast( - start_index[i], index[i]->getType()); - output_index[i] = ir_builder_.CreateURem( - ir_builder_.CreateAdd(start_index0, index[i]), dim_size); - } - - // Read value from 'update'. - llvm_ir::IrArray update_array(GetIrArrayForOp(update)); - llvm::Value* update_data = - update_array.EmitReadArrayElement(index, &ir_builder_); - - // Write value to output array. - GetIrArrayForOp(operand).EmitWriteArrayElement(output_index, update_data, - &ir_builder_); - return Status::OK(); - }; - - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(loop_body_emitter, update->shape(), &ir_builder_) - .EmitLoop(IrName(dynamic_update_slice, "in_place"))); - - TF_ASSIGN_OR_RETURN(llvm::Value * dynamic_update_slice_address, - EmitTargetAddressForOp(dynamic_update_slice)); - emitted_value_[dynamic_update_slice] = dynamic_update_slice_address; - return Status::OK(); + } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice, + assignment_)) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); + auto operands = GetIrArraysForOperandsOf(dynamic_update_slice); + return llvm_ir::EmitDynamicUpdateSliceInPlace( + operands, GetIrArrayFor(dynamic_update_slice), + IrName(dynamic_update_slice, "in_place"), &ir_builder_); } return DefaultAction(dynamic_update_slice); } @@ -2283,7 +2182,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); // Load an element from the operand. - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); @@ -2303,7 +2202,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { } // Store the operand element to the computed output location. - llvm_ir::IrArray output_array(GetIrArrayForOp(pad)); + llvm_ir::IrArray output_array(GetIrArrayFor(pad)); output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); @@ -2319,11 +2218,11 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { } Status IrEmitter::HandleFusion(HloInstruction* fusion) { + auto* root = fusion->fused_expression_root(); if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - const HloInstruction* dot = fusion->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(root->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && rhs_parameter->opcode() == HloOpcode::kParameter); const HloInstruction* lhs = @@ -2332,18 +2231,15 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { fusion->operand(rhs_parameter->parameter_number()); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*dot, /*operands=*/{lhs, rhs}, + /*instruction=*/*root, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32})); - llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); Shape target_shape = fusion->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*fusion, &target_array); - + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array = GetIrArrayFor(fusion); VLOG(2) << "HandleFusion kTransposeDot: "; VLOG(2) << " lhs operand: " << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); @@ -2354,19 +2250,27 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { // Dot operation is complicated so we delegate to a helper class. TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, dot->operand(0)->IsRank2Transpose(), - dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array, - GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_)); - - emitted_value_[fusion] = target_address; + *root, root->operand(0)->IsRank2Transpose(), + root->operand(1)->IsRank2Transpose(), target_array, lhs_array, + rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, + hlo_module_config_)); return Status::OK(); + } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, + assignment_)) { + VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + + // Delegate to common implementation of fused in-place dynamic-update-slice. + auto operands = GetIrArraysForOperandsOf(fusion); + return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( + fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, + &ir_builder_); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { - std::vector parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArrayForOp(operand)); - } + VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + auto operands = GetIrArraysForOperandsOf(fusion); + FusedIrEmitter fused_emitter(operands, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); @@ -2384,14 +2288,20 @@ Status IrEmitter::HandleCall(HloInstruction* call) { parameter_addresses.push_back(GetEmittedValueFor(operand)); } - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(call)); - output_address->setName(AsStringRef(IrName(call))); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - output_address, computation->name()); + if (!computation->root_instruction()->outer_dimension_partitions().empty() && + !parallel_cpu_backend_) { + // ParallelTaskAssignment assigned partitions, emit call to + // ParallelForkJoin. + TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, + emitted_value_[call], computation, + call_ir_function)); + } else { + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, + emitted_value_[call], computation->name()); + } - emitted_value_[call] = output_address; return Status::OK(); } @@ -2420,17 +2330,13 @@ Status IrEmitter::HandleCustomCall( /*Params=*/{i8_ptr_type, operands_alloca->getType()}, /*isVarArg=*/false))); - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(custom_call)); - output_address->setName(AsStringRef(IrName(custom_call))); - - auto* output_address_arg = - ir_builder_.CreatePointerCast(output_address, i8_ptr_type); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + auto* output_address_arg = ir_builder_.CreatePointerCast( + GetEmittedValueFor(custom_call), i8_ptr_type); ir_builder_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); - emitted_value_[custom_call] = output_address; return Status::OK(); } @@ -2505,8 +2411,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { {while_result}, IrName(xla_while, "cond")); llvm::Value* while_predicate = ir_builder_.CreateICmpNE( while_condition, - llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), - 0)); + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( @@ -2574,10 +2479,8 @@ StatusOr IrEmitter::EmitFastConcatenate( llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::Type* i8_type = ir_builder_.getInt8Ty(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(concatenate)); - - llvm_ir::IrArray target_array(target_address, output_shape); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); + llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_); llvm_ir::IrArray::Index outer_dims_index = @@ -2594,8 +2497,6 @@ StatusOr IrEmitter::EmitFastConcatenate( unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - AddAliasingInformationToIrArray(*concatenate, &target_array); - // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = ir_builder_.CreateBitCast( @@ -2614,7 +2515,7 @@ StatusOr IrEmitter::EmitFastConcatenate( // equal to the product of inner dimensions. for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); - llvm_ir::IrArray source_array = GetIrArrayForOp(operand); + llvm_ir::IrArray source_array = GetIrArrayFor(operand); llvm::Value* copy_source_address = ir_builder_.CreateBitCast( source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, "src_addr"), @@ -2638,8 +2539,6 @@ StatusOr IrEmitter::EmitFastConcatenate( SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); } - emitted_value_[concatenate] = target_address; - return true; } @@ -2653,7 +2552,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, unsigned element_alignment = GCD( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( - llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_)); + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { auto* load_instruction = ir_builder_.CreateAlignedLoad( @@ -2711,7 +2610,7 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // For the parallel cpu backend, we record the total for each embedded // computation callee with its caller kCall HLO. HloInstruction* hlo_to_lookup = nullptr; - if (IsParallelContext()) { + if (parallel_cpu_backend_ && is_top_level_computation_) { auto* computation = root->parent(); auto* entry_computation = computation->parent()->entry_computation(); if (computation != entry_computation) { @@ -2839,7 +2738,7 @@ Status IrEmitter::Postprocess(HloInstruction* hlo) { return Status::OK(); } -llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) { +llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) { llvm::Value* value_for_op = GetEmittedValueFor(hlo); llvm_ir::IrArray array(value_for_op, hlo->shape()); @@ -2847,6 +2746,16 @@ llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) { return array; } +std::vector IrEmitter::GetIrArraysForOperandsOf( + const HloInstruction* hlo) { + std::vector arrays; + std::transform( + hlo->operands().begin(), hlo->operands().end(), + std::back_inserter(arrays), + [&](const HloInstruction* operand) { return GetIrArrayFor(operand); }); + return arrays; +} + llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { auto it = emitted_value_.find(hlo); if (it == emitted_value_.end()) { @@ -2856,7 +2765,22 @@ llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { } llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { - return llvm_ir::ShapeToIrType(shape, &ir_builder_); + return llvm_ir::ShapeToIrType(shape, module_); +} + +std::vector IrEmitter::GetComputeFunctionParams() { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } + if (hlo_to_profile_idx_) { + compute_function_params.push_back(i64_ptr_type); + } + return compute_function_params; } llvm::Argument* IrEmitter::GetResultArgument() { @@ -2864,7 +2788,7 @@ llvm::Argument* IrEmitter::GetResultArgument() { } llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = IsParallelContext() ? 5 : 4; + const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; } @@ -2915,14 +2839,12 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( ir_builder_.CreateLoad(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. + // Loading the address of a buffer is invariant of the point at which the + // load is executed in the program because we never reassign buffers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); } - llvm_ir::SetTbaaForInstruction(tempbuf_address_base, target_shape, - /*is_pointer_to=*/true); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); @@ -2949,18 +2871,11 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector IrEmitter::GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { llvm::Value* parameter_addresses_buffer = @@ -2989,7 +2904,26 @@ void IrEmitter::EmitArrayFunctionCallInto( if (auto* profile_counters = GetProfileCountersArgument()) { arguments.push_back(profile_counters); } - ir_builder_.CreateCall(function, arguments); + return arguments; +} + +// Emits a core function call based on the following pseudo-code. +// +// char** parameter_addresses_buffer = +// allocate buffer with a pointer for each parameter to the function +// for each parameter index, i.e. for i = 0, ..., #parameters: +// parameter_addresses_buffer[i] = parameter_addresses[i] +// call function(return_value_buffer, +// parameter_addresses_buffer, +// temps) +// return return_value_buffer -- address of the return value. +void IrEmitter::EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name) { + ir_builder_.CreateCall( + function, GetArrayFunctionCallArguments(parameter_addresses, + return_value_buffer, name)); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -3001,7 +2935,7 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( PrimitiveType return_type = return_shape.element_type(); llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - llvm_ir::PrimitiveTypeToIrType(return_type, &ir_builder_), elements, + llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, tensorflow::strings::StrCat(name, "_return_value_address"), &ir_builder_, MinimumAlignmentForPrimitiveType(return_type)); EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, @@ -3009,10 +2943,114 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -StatusOr IrEmitter::EmitTargetAddressForOp( - const HloInstruction* op, const ShapeIndex& shape_index) { - const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index); - if (op == op->parent()->root_instruction() && shape_index.empty()) { +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status IrEmitter::EmitParallelForkJoin( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function) { + HloInstruction* root = computation->root_instruction(); + + // Build ParallelForkJoin function type. + std::vector compute_function_params = GetComputeFunctionParams(); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module_->getContext())); + // Number of partitioned most-major dimensions in 'root.shape'. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module_->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module_->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module_->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + const string name = computation->name(); + std::vector arguments = + GetArrayFunctionCallArguments(parameter_addresses, output_address, name); + + // Create ShapePartitionIterator to generate all partitions of 'root.shape'. + ShapePartitionIterator partition_iterator(root->shape(), + root->outer_dimension_partitions()); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + arguments.push_back(ir_builder_.getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'root.shape'. + const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder_.getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder_.getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + arguments.push_back(ir_builder_.CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module_->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + arguments.push_back( + ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder_.CreateCall(fork_join_func, arguments); + + return Status::OK(); +} + +Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { + llvm::Value* addr; + const Shape& target_shape = op->shape(); + if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = GetResultArgument(); @@ -3022,15 +3060,18 @@ StatusOr IrEmitter::EmitTargetAddressForOp( attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); retval->addAttrs(attr_builder); } - return ir_builder_.CreateBitCast(retval, + addr = ir_builder_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); - } - - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - return EmitTempBufferPointer(slice, target_shape); + } else { + // For other nodes, we need the temporary buffer allocated for this node to + // write the result into. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + addr = EmitTempBufferPointer(slice, target_shape); + } + addr->setName(AsStringRef(IrName(op))); + emitted_value_[op] = addr; + return Status::OK(); } Status IrEmitter::EmitTargetElementLoop( @@ -3044,12 +3085,9 @@ Status IrEmitter::EmitTargetElementLoop( const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); - // target_address will hold the address of the target buffer we will write the - // result of the computation into. const Shape& target_shape = target_op->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(target_op)); - VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); + llvm_ir::IrArray target_array = GetIrArrayFor(target_op); if (target_op->IsMultiOutputFusion()) { // For multiple outputs fusion, we need to emit each operand and the root. @@ -3072,13 +3110,9 @@ Status IrEmitter::EmitTargetElementLoop( for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape), - tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_); } else { - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*target_op, &target_array); - if (ShouldEmitParallelLoopFor(*target_op)) { TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( target_shape, element_generator, IrName(target_op), &target_array)); @@ -3088,8 +3122,6 @@ Status IrEmitter::EmitTargetElementLoop( .EmitLoop(IrName(target_op))); } } - - emitted_value_[target_op] = target_address; return Status::OK(); } @@ -3181,7 +3213,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_); }; } CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 8042e03e69561aeacccc5498eaf52f32bbd78b62..58c185af1ec6e00c854b39f77281da7760855fe0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -104,11 +105,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. + // external_constant_pool: if non-null, points to an ExternalConstantPool + // instance into which the Ir emitter can spill + // constants. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx, - llvm::TargetMachine* target_machine); + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -146,7 +151,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // // Default action which emits code for most operations. Operations which are // special in some way are handled explicitly in HandleFoo methods. - Status DefaultAction(HloInstruction* hlo_instruction) override; + Status DefaultAction(HloInstruction* hlo) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant, @@ -175,7 +180,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleReduceWindow(HloInstruction* reduce_window, HloInstruction* operand, const Window& window, HloComputation* function) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; Status HandleSend(HloInstruction* send) override; Status HandleSlice(HloInstruction* slice, HloInstruction* /*operand*/) override; @@ -208,7 +213,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; - Status Postprocess(HloInstruction* visited) override; + Status Postprocess(HloInstruction* hlo) override; private: // Private helper to initialize an IR function for the computation. @@ -220,8 +225,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Gets the IR Value emitted previously for the given hlo. // - // Prefer calling GetIrArrayForOp if the value you're reading is a buffer, - // because GetIrArrayForOp annotates buffer's loads/stores with noalias + // Prefer calling GetIrArrayFor if the value you're reading is a buffer, + // because GetIrArrayFor annotates buffer's loads/stores with noalias // metadata. // // Make sure to call this only when you're certain a value *was* emitted - if @@ -229,7 +234,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetEmittedValueFor(const HloInstruction* hlo); // Gets an IrArray representing the given hlo. - llvm_ir::IrArray GetIrArrayForOp(const HloInstruction* hlo); + llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo); + + // Gets a list of IrArrays, one for each of hlo's operands. + std::vector GetIrArraysForOperandsOf( + const HloInstruction* hlo); // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, @@ -240,6 +249,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); + // Returns an array of compute function parameter types. + std::vector GetComputeFunctionParams(); + // Get the llvm::Value* that represents the "retval" argument of the // computation function being emitted by this emitter. llvm::Argument* GetResultArgument(); @@ -304,7 +316,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { void EmitArrayFunctionCallInto( llvm::Function* function, tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value, tensorflow::StringPiece name); + llvm::Value* return_value_buffer, tensorflow::StringPiece name); // Array function call emitter. Returns a Value for the function's return // value buffer address. The return value buffer is alloca'ed by this @@ -314,6 +326,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); + // Returns an array of compute function call arguments. + std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name); + + // Emits a call to a runtime fork/join function which dispatches parallel + // calls to 'parallel_function' (and joins threads before returning). + Status EmitParallelForkJoin( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function); + // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -353,11 +377,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitMemcpy(const HloInstruction& source, const HloInstruction& destination); - // Emit IR to compute the target address of the buffer for the given op. - // The returned Value is a pointer to a IR type that represents the op's - // element type. - StatusOr EmitTargetAddressForOp( - const HloInstruction* op, const ShapeIndex& shape_index = {}); + // Emits IR to compute the target address of the buffer for the given op. + // After calling this function, you can get a pointer to this buffer by + // calling GetIrArrayForOp or GetEmittedValueFor. + Status EmitTargetAddressForOp(const HloInstruction* op); // Structurizes "array_elements" into an MD array that represents "shape". // This is a recursive function, and "dimension_index" indicates the index of @@ -447,10 +470,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); - // Name of the computation entry function. This function serves as the - // top-level "main" of the computation and will be invoked by the JIT. - string entry_function_name_; - // Assignment of the temporary buffers needed by the computation and their // shape information. const BufferAssignment& assignment_; @@ -592,12 +611,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - // Returns true if the current function being emitted is called in a - // parallel context (returns false otherwise). - bool IsParallelContext() { - return parallel_cpu_backend_ && is_top_level_computation_; - } - const HloModuleConfig& hlo_module_config_; const bool parallel_cpu_backend_; @@ -606,6 +619,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { TargetMachineFeatures target_machine_features_; + int64 external_global_constant_counter_ = 0; + ExternalConstantPool* external_constant_pool_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index d4b5e41f5090856d68135643c1d2ee94c27491db..c2213c8f2ef592c537daf9abe2ffa10b83a8fa4c 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -48,29 +48,56 @@ class SimpleCostModel : public ParallelCostModel { class DefaultCostModel : public ParallelCostModel { public: DefaultCostModel(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, std::unique_ptr cost_analysis) : max_parallelism_(max_parallelism), + shape_size_(shape_size), cost_analysis_(std::move(cost_analysis)) {} ~DefaultCostModel() override {} int64 GetParallelTaskCount(HloInstruction* instruction) override { - // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. - // Consider making 'min_cost_per_thread' be a function of the target - // bandwidth limit for instructions with low arithmetic complexity. - const int64 instruction_cost = - 1 * cost_analysis_->flop_count(*instruction) + - 2 * cost_analysis_->transcendental_count(*instruction) + - 10 * cost_analysis_->bytes_accessed(*instruction); - // Minimum per-thread cost is 100us of work on a 2GHz core. - const int64 min_cost_per_thread = 100000; + // Parameters for parallel task count computation. + int64 instruction_cost; + int64 min_cost_per_thread; + int64 max_parallelism; + // Calculate flops-to-bytes-ratio for 'instruction'. + const int64 bytes_accessed = + std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + const float flops_to_bytes_ratio = + cost_analysis_->flop_count(*instruction) / + static_cast(bytes_accessed); + // Check for I/O bound instructions. + if (flops_to_bytes_ratio <= 1.0) { + // Limit max parallelism for I/O bound instructions by assuming a + // sub-linear scaling function (fit based on empirical benchmark results). + // TODO(29630486) Develop system bandwidth model. + max_parallelism = + std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); + // Use shape size instruction cost and L2 cache size min per-thread cost. + instruction_cost = shape_size_(instruction->shape()); + min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + } else { + // Use max parallelism for compute bound instructions. + max_parallelism = max_parallelism_; + // Calculate the instruction cost in cycles. + // TODO(29630486) Improve on this linear cost model. + // Consider making 'min_cost_per_thread' be a function of the target + // bandwidth limit for instructions with low arithmetic complexity. + instruction_cost = + 1 * cost_analysis_->flop_count(*instruction) + + 2 * cost_analysis_->transcendental_count(*instruction) + + 10 * cost_analysis_->bytes_accessed(*instruction); + // Minimum per-thread cost is 100us of work on a 2GHz core. + min_cost_per_thread = 100000; + } // Return target parallel task count in [1, max_parallelism_]. - return std::min(max_parallelism_, + return std::min(max_parallelism, std::max(1LL, instruction_cost / min_cost_per_thread)); } private: const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; const std::unique_ptr cost_analysis_; }; @@ -86,7 +113,7 @@ ParallelTaskAssignment::ParallelTaskAssignment( Status status = computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { // Set default cost model based on 'cost_analysis'. - cost_model_.reset(new DefaultCostModel(max_parallelism, + cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size, std::move(cost_analysis))); } else { // Fall back to a simple cost model based on hlo size and L2 cache size. @@ -109,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kCustomCall || instruction->opcode() == HloOpcode::kSelectAndScatter || + instruction->opcode() == HloOpcode::kGetTupleElement || + instruction->opcode() == HloOpcode::kBitcast || (instruction->opcode() == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || @@ -121,5 +150,102 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( return cost_model_->GetParallelTaskCount(instruction); } +StatusOr ParallelTaskAssigner::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); + XLA_VLOG_LINES(3, module->ToString()); + + // Compute target parallel task counts for all instructions in 'module'. + HloToParallelTasks hlo_to_parallel_tasks; + ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); + + // Assign parallel tasks to target specific instructions in 'module'. + // TODO(b/27458679) Support inter-op parallelism. + bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks); + + XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT"); + XLA_VLOG_LINES(3, module->ToString()); + return changed; +} + +bool ParallelTaskAssigner::AssignParallelTasks( + HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) { + return AssignParallelTasksHelper(module, module->entry_computation(), + hlo_to_parallel_tasks); +} + +bool ParallelTaskAssigner::AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks) { + bool changed = false; + // Snapshot set of instructions because outlining modifies the set below. + std::vector instructions(computation->instructions().begin(), + computation->instructions().end()); + for (auto* instruction : instructions) { + // Assign parallel tasks to sub-computations for While and Call HLOs. + // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement, + // and support other callable computations like reduce. + if (instruction->opcode() == HloOpcode::kWhile) { + changed |= AssignParallelTasksHelper(module, instruction->while_body(), + hlo_to_parallel_tasks); + continue; + } else if (instruction->opcode() == HloOpcode::kCall) { + changed |= AssignParallelTasksHelper(module, instruction->to_apply(), + hlo_to_parallel_tasks); + continue; + } + // Skip if no parallel tasks were computed in first pass. + auto it = hlo_to_parallel_tasks.find(instruction); + if (it == hlo_to_parallel_tasks.end()) { + continue; + } + // Get target parallel task count computed for 'instruction'. + const int64 target_parallel_task_count = (*it).second; + // Assign feasible dimension partitions (based on actual dimension sizes). + auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) + .Run(target_parallel_task_count); + const int64 total_partition_count = + ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); + if (total_partition_count <= 1) { + // Feasible partition calculation resulting in no partitioning, so skip. + continue; + } + + // Outline 'instruction' in 'computation' for parallel task assignment. + auto* call = module->OutlineExpressionFromComputation( + {instruction}, + tensorflow::strings::StrCat("parallel_", instruction->name()), + computation); + + // Set assigned dimension partitioning to 'instruction'. + auto* new_root = call->to_apply()->root_instruction(); + new_root->set_outer_dimension_partitions(dim_partition_counts); + + VLOG(2) << "Assigned parallel task count: " << total_partition_count + << " to instruction: " << new_root->name() + << " parent: " << new_root->parent()->name(); + changed = true; + } + return changed; +} + +void ParallelTaskAssigner::ComputeTargetParallelTasks( + HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { + // Compute parallel task counts for all instructions in 'module'. + for (auto* computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (auto* instruction : computation->instructions()) { + // Query ParallelTaskAssignment for target parallel task count. + const int64 target_parallel_task_count = + parallel_task_assignment_.GetTargetParallelTaskCount(instruction); + if (target_parallel_task_count > 1) { + hlo_to_parallel_tasks->insert( + {instruction, target_parallel_task_count}); + } + } + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 15f065a3ad44b39819a62bc0447785596a3bd29c..e036da5784f6151eb3b01107ec7f3ab820071a60 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace cpu { @@ -49,6 +50,54 @@ class ParallelTaskAssignment { std::unique_ptr cost_model_; }; +// ParallelTaskAssigner computes target parallel task counts for all HLOs +// in the module, then assigns parallel task counts to HLOs in the entry +// computation, or to HLOs in embedded computations invoked by (potentially +// nested) kWhile or kCall instructions. +// Each HLO which is assigned parallel task counts is outlined into its +// own embedded computation, which is compiled as a parallel compute function, +// and which is invoked from a kCall instruction that is lowered in codegen to +// a runtime parallel fork/join call. +class ParallelTaskAssigner : public HloPassInterface { + public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. + // 'module': the containing HloModule. + ParallelTaskAssigner(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module) + : parallel_task_assignment_(max_parallelism, shape_size, module) {} + ~ParallelTaskAssigner() override {} + + tensorflow::StringPiece name() const override { + return "cpu-parallel-task-assigner"; + } + + // Run parallel task assigner on 'module'. + // Returns true if the computation was changed, false otherwise. + StatusOr Run(HloModule* module) override; + + private: + using HloToParallelTasks = std::unordered_map; + + // Assigns target parallel tasks from 'hlo_to_parallel_tasks' to HLOs in + // 'module'. + // Returns true if the computation was changed, false otherwise. + bool AssignParallelTasks(HloModule* module, + const HloToParallelTasks& hlo_to_parallel_tasks); + bool AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks); + + // Computes target parallel task counts (returned in 'parallel_task_counts') + // for parallelizable instructions in 'module'. + void ComputeTargetParallelTasks(HloModule* module, + HloToParallelTasks* hlo_to_parallel_tasks); + + ParallelTaskAssignment parallel_task_assignment_; +}; + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc new file mode 100644 index 0000000000000000000000000000000000000000..d03da46575b331de113cc5f33c2b4267504e8308 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::uint64; + +using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, + int64*, uint64*); + +// Dispatches 'num_partitions - 1' calls to 'function_ptr' in parallel. +// Calls 'function_ptr' for first partition inline. +// Uses blocking counter to synchonize threads after parallel calls complete. +// +// The 'partitions' array has a total number of elements equal to +// 'num_partitions * num_partitioned_dims * 2' (the '2' is necessary to specify +// dimension start and limit indices). +// +// The 'partitions' array layout stores array elements in memory with dimension +// start limit as the most-minor dimension, followed by dimension, then +// partition. +// +// EX: Layout of 'partitions' array with 'num_partitions = 2', and +// 'num_partitioned_dims = 3' +// +// [partition0_dim0_start] +// [partition0_dim0_limit] +// [partition0_dim1_start] +// [partition0_dim1_limit] +// [partition0_dim2_start] +// [partition0_dim2_limit] +// [partition1_dim0_start] +// [partition1_dim0_limit] +// [partition1_dim1_start] +// [partition1_dim1_limit] +// [partition1_dim2_start] +// [partition1_dim2_limit] +// +void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, uint64* prof_counters, int32 num_partitions, + int64* partitions, int32 num_partitioned_dims, void* function_ptr) { + VLOG(2) << "ParallelForkJoin ENTRY" + << " num_partitions: " << num_partitions + << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_GT(num_partitions, 1); + CHECK_GT(num_partitioned_dims, 0); + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + ComputeFunctionType function = + reinterpret_cast(function_ptr); + // Compute partition stride in 'partitions' array. + const int64 stride = 2 * num_partitioned_dims; + + // Dispatch 'num_partitions - 1' compute functions to run in parallel. + tensorflow::BlockingCounter bc(num_partitions - 1); + for (int32 i = 1; i < num_partitions; ++i) { + const int64 offset = i * stride; + run_options->intra_op_thread_pool()->enqueueNoNotification( + [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + partitions, offset, &bc]() { + function(result_ptr, run_options_ptr, params, temps, + &partitions[offset], prof_counters); + bc.DecrementCount(); + VLOG(3) << "ParallelForkJoin partition " << i << " done."; + }); + } + + // Call first compute function inline. + function(result_ptr, run_options_ptr, params, temps, &partitions[0], + prof_counters); + VLOG(3) << "ParallelForkJoin partition 0 done."; + bc.Wait(); + VLOG(2) << "ParallelForkJoin EXIT"; +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h new file mode 100644 index 0000000000000000000000000000000000000000..fcf1cc62078d3847435a2e75e3ca9d109cf8b200 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// Dispatches 'num_partitions' parallel calls to 'function_ptr' and joins +// threads before returning. See comments in runtime_fork_join.cc for details. +extern void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, tensorflow::uint64* prof_counters, + tensorflow::int32 num_partitions, tensorflow::int64* partitions, + tensorflow::int32 num_partitioned_dims, void* function_ptr); + +} // extern "C" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index c3c11df090e88c3c24104b66d28b3b16f03baa80..fdf02e5b422f75e256feec77470bb0d079e8ef1f 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -31,7 +31,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -42,90 +44,21 @@ namespace xla { namespace cpu { namespace { -// Converts a symbol 'name' into the form expected by dlsym(). -std::string CanonicalizeSymbol(const std::string& name) { -#if defined(__APPLE__) - // On Mac OS X, dlsym() expects names not to be prefixed with a leading - // underscore. - if (!name.empty() && name.front() == '_') { - return name.substr(1); - } -#endif - return name; -} - -class JITSymbolTable { +// A simple SymbolResolver that delegates to the host dynamic linker. +class SimpleResolver : public llvm::JITSymbolResolver { public: - JITSymbolTable() { Populate(); } - - void* Lookup(llvm::StringRef jit_symbol_name) const { - auto it = jit_symbol_table_.find(jit_symbol_name); - return it == jit_symbol_table_.end() ? nullptr : it->getValue(); - } - - static bool MustBeInTable(llvm::StringRef name) { - // In particular, names starting with - // runtime::kXlaCpuRuntimeSymbolNamePrefix should not be dlsym'ed. - return name.startswith(runtime::kXlaCpuRuntimeSymbolNamePrefix); - } - - private: - void AddJITSymbolToTable(llvm::StringRef jit_symbol_name, - llvm::StringRef cpp_symbol_name, - void* jit_symbol_value) { - // The JIT symbol name and the C++ symbol name (with an extern "C" linkage) - // need to match, otherwise AOT links will fail. - CHECK(jit_symbol_name == cpp_symbol_name); - CHECK(jit_symbol_table_.insert({jit_symbol_name, jit_symbol_value}).second); - } - - void Populate() { -#define ADD_JIT_SYMBOL_TO_TABLE(base_name) \ - do { \ - AddJITSymbolToTable( \ - xla::cpu::runtime::k##base_name##SymbolName, \ - "__xla_cpu_runtime_" #base_name, \ - reinterpret_cast(__xla_cpu_runtime_##base_name)); \ - } while (false) - - ADD_JIT_SYMBOL_TO_TABLE(AcquireInfeedBufferForDequeue); - ADD_JIT_SYMBOL_TO_TABLE(ReleaseInfeedBufferAfterDequeue); - ADD_JIT_SYMBOL_TO_TABLE(AcquireOutfeedBufferForPopulation); - ADD_JIT_SYMBOL_TO_TABLE(ReleaseOutfeedBufferAfterPopulation); - ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32AVX); - ADD_JIT_SYMBOL_TO_TABLE(LogV8F32AVX); - ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32SSE); - ADD_JIT_SYMBOL_TO_TABLE(LogV4F32SSE); - ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32NEON); - ADD_JIT_SYMBOL_TO_TABLE(LogV4F32NEON); - ADD_JIT_SYMBOL_TO_TABLE(EigenConvF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF64); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64); - -#undef ADD_JIT_SYMBOL_TO_TABLE - } - - llvm::StringMap jit_symbol_table_; -}; - -const JITSymbolTable& GetJITSymbolTable() { - static JITSymbolTable* symbol_table = new JITSymbolTable; - return *symbol_table; -} + explicit SimpleResolver(ExternalConstantPool* external_constant_pool) + : external_constant_pool_(external_constant_pool) {} -// A simple SymbolResolver that delegates to the host dynamic linker. -struct SimpleResolver : public llvm::JITSymbolResolver { llvm::JITSymbol findSymbol(const std::string& name) override { - std::string canonical_name = CanonicalizeSymbol(name); - const JITSymbolTable& jit_symbol_table = GetJITSymbolTable(); - - void* func_addr = JITSymbolTable::MustBeInTable(canonical_name) - ? jit_symbol_table.Lookup(canonical_name) - : dlsym(RTLD_DEFAULT, canonical_name.c_str()); + if (const uint8* from_constant_pool = + external_constant_pool_->Find(string(name))) { + return llvm::JITEvaluatedSymbol( + reinterpret_cast(from_constant_pool), + llvm::JITSymbolFlags::None); + } + void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); if (func_addr == nullptr) { return nullptr; } @@ -136,6 +69,9 @@ struct SimpleResolver : public llvm::JITSymbolResolver { llvm::JITSymbol findSymbolInLogicalDylib(const std::string& name) override { return nullptr; } + + private: + ExternalConstantPool* external_constant_pool_; }; llvm::SmallVector DetectMachineAttributes() { @@ -205,7 +141,7 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto handle = cantFail(compile_layer_.addModule( - std::move(module), MakeUnique())); + std::move(module), MakeUnique(external_constant_pool()))); module_handles_.push_back(handle); return handle; } @@ -238,5 +174,118 @@ llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { return nullptr; } +namespace { +// Register some known symbols with the CustomCallTargetRegistry. +bool RegisterKnownJITSymbols() { + CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); + +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + reinterpret_cast(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ( \ + tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ + } while (false) + + REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); + REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); + REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); + REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); + REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); + REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); + REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); + REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + +#undef REGISTER_CPU_RUNTIME_SYMBOL + +#define REGISTER_LIBM_SYMBOL(name) \ + do { \ + /* Register both the F32 and F64 variants of the libm symbol. */ \ + registry->Register(#name "f", reinterpret_cast(name##f)); \ + registry->Register(#name, reinterpret_cast(name)); \ + } while (false) + + REGISTER_LIBM_SYMBOL(acos); + REGISTER_LIBM_SYMBOL(acosh); + REGISTER_LIBM_SYMBOL(asin); + REGISTER_LIBM_SYMBOL(asinh); + REGISTER_LIBM_SYMBOL(atan); + REGISTER_LIBM_SYMBOL(atan2); + REGISTER_LIBM_SYMBOL(atanh); + REGISTER_LIBM_SYMBOL(cbrt); + REGISTER_LIBM_SYMBOL(ceil); + REGISTER_LIBM_SYMBOL(copysign); + REGISTER_LIBM_SYMBOL(cos); + REGISTER_LIBM_SYMBOL(cosh); + REGISTER_LIBM_SYMBOL(erf); + REGISTER_LIBM_SYMBOL(erfc); + REGISTER_LIBM_SYMBOL(exp); + REGISTER_LIBM_SYMBOL(exp2); + REGISTER_LIBM_SYMBOL(expm1); + REGISTER_LIBM_SYMBOL(fabs); + REGISTER_LIBM_SYMBOL(fdim); + REGISTER_LIBM_SYMBOL(floor); + REGISTER_LIBM_SYMBOL(fma); + REGISTER_LIBM_SYMBOL(fmax); + REGISTER_LIBM_SYMBOL(fmin); + REGISTER_LIBM_SYMBOL(fmod); + REGISTER_LIBM_SYMBOL(frexp); + REGISTER_LIBM_SYMBOL(hypot); + REGISTER_LIBM_SYMBOL(ilogb); + REGISTER_LIBM_SYMBOL(ldexp); + REGISTER_LIBM_SYMBOL(lgamma); + REGISTER_LIBM_SYMBOL(llrint); + REGISTER_LIBM_SYMBOL(llround); + REGISTER_LIBM_SYMBOL(log); + REGISTER_LIBM_SYMBOL(log10); + REGISTER_LIBM_SYMBOL(log1p); + REGISTER_LIBM_SYMBOL(log2); + REGISTER_LIBM_SYMBOL(logb); + REGISTER_LIBM_SYMBOL(lrint); + REGISTER_LIBM_SYMBOL(lround); + REGISTER_LIBM_SYMBOL(modf); + REGISTER_LIBM_SYMBOL(nan); + REGISTER_LIBM_SYMBOL(nearbyint); + REGISTER_LIBM_SYMBOL(nextafter); + REGISTER_LIBM_SYMBOL(nexttoward); + REGISTER_LIBM_SYMBOL(pow); + REGISTER_LIBM_SYMBOL(remainder); + REGISTER_LIBM_SYMBOL(remquo); + REGISTER_LIBM_SYMBOL(rint); + REGISTER_LIBM_SYMBOL(round); + REGISTER_LIBM_SYMBOL(scalbln); + REGISTER_LIBM_SYMBOL(scalbn); + REGISTER_LIBM_SYMBOL(sin); + REGISTER_LIBM_SYMBOL(sincos); + REGISTER_LIBM_SYMBOL(sinh); + REGISTER_LIBM_SYMBOL(sqrt); + REGISTER_LIBM_SYMBOL(tan); + REGISTER_LIBM_SYMBOL(tanh); + REGISTER_LIBM_SYMBOL(tgamma); + REGISTER_LIBM_SYMBOL(trunc); + +#undef REGISTER_LIBM_SYMBOL + + registry->Register("memcpy", reinterpret_cast(memcpy)); + registry->Register("memmove", reinterpret_cast(memmove)); + registry->Register("memset", reinterpret_cast(memset)); + return true; +} + +bool unused = RegisterKnownJITSymbols(); +} // namespace + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index e476c0e3812cc0fb2a2d633832374b3165ca072a..ded01e9e4d7442296f7406dd035e6ab385458238 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -90,6 +91,10 @@ class SimpleOrcJIT { llvm::TargetMachine* target_machine() const { return target_machine_.get(); } + ExternalConstantPool* external_constant_pool() { + return &external_constant_pool_; + } + private: std::vector module_handles_; std::unique_ptr target_machine_; @@ -97,6 +102,7 @@ class SimpleOrcJIT { const llvm::DataLayout data_layout_; ObjLayerT object_layer_; CompileLayerT compile_layer_; + ExternalConstantPool external_constant_pool_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc new file mode 100644 index 0000000000000000000000000000000000000000..d124f74d19d83269be96ee34a6b4b2a8d00a978f --- /dev/null +++ b/tensorflow/compiler/xla/service/defuser.cc @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/defuser.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +namespace { + +// Copy all the instructions in the given fusion instruction into the fusion +// instruction's parent computation and replace the use of the fusion +// instruction with the copy of the fusion expression root. +Status Defuse(HloInstruction* fusion_instruction) { + VLOG(2) << "Defusing instruction: " << fusion_instruction->ToString(); + + HloComputation* fused_computation = + fusion_instruction->fused_instructions_computation(); + + // A map from fused instruction to its defused clone. + tensorflow::gtl::FlatMap + defused_instructions; + // Initialize map to contain the fusion instruction parameters mapping + // to the operands of the fusion instruction. + for (int64 i = 0; i < fusion_instruction->operand_count(); ++i) { + defused_instructions[fused_computation->parameter_instruction(i)] = + fusion_instruction->mutable_operand(i); + } + + // Create a clone of each instruction of the fused computation in the same + // computation as the fusion instruction itself. + // TODO(b/68227302): Moving instruction to new computation rather than + // cloning and deleting. + for (HloInstruction* fused_instruction : + fused_computation->MakeInstructionPostOrder()) { + if (fused_instruction->opcode() == HloOpcode::kParameter) { + continue; + } + std::vector new_operands; + for (HloInstruction* operand : fused_instruction->operands()) { + new_operands.push_back(defused_instructions.at(operand)); + } + HloInstruction* defused_instruction = + fusion_instruction->parent()->AddInstruction( + fused_instruction->CloneWithNewOperands(fused_instruction->shape(), + new_operands)); + defused_instructions[fused_instruction] = defused_instruction; + } + + TF_RETURN_IF_ERROR(fusion_instruction->ReplaceAllUsesWith( + defused_instructions.at(fusion_instruction->fused_expression_root()))); + + HloModule* module = fusion_instruction->parent()->parent(); + TF_RETURN_IF_ERROR( + fusion_instruction->parent()->RemoveInstruction(fusion_instruction)); + return module->RemoveEmbeddedComputation(fused_computation); +} + +} // namespace + +StatusOr Defuser::Run(HloModule* module) { + VLOG(1) << "Defusing module " << module->name(); + XLA_VLOG_LINES(2, "Before defusion:\n" + module->ToString()); + + bool changed = false; + std::unique_ptr call_graph = CallGraph::Build(module); + TF_RETURN_IF_ERROR(call_graph->VisitNodes( + [&](const CallGraphNode& call_graph_node) -> Status { + if (call_graph_node.computation()->IsFusionComputation()) { + TF_RET_CHECK(call_graph_node.caller_callsites().size() == 1); + HloInstruction* fusion_instruction = + call_graph_node.caller_callsites()[0].instruction(); + TF_RETURN_IF_ERROR(Defuse(fusion_instruction)); + changed = true; + } + return Status::OK(); + }, + /*visit_unreachable_nodes=*/true)); + + XLA_VLOG_LINES(2, "After defusion:\n" + module->ToString()); + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h new file mode 100644 index 0000000000000000000000000000000000000000..56b28fd22da1ea6bc19f98e76f0f2ef4044cd3af --- /dev/null +++ b/tensorflow/compiler/xla/service/defuser.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEFUSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DEFUSER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which replaces all fusion instructions with the equivalent un-fused +// instructions. +class Defuser : public HloPassInterface { + public: + Defuser() {} + ~Defuser() override {} + tensorflow::StringPiece name() const override { return "defuser"; } + + // Run defusion on the given module. Returns whether the module was + // changed. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DEFUSER_H_ diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..32b5c5d35fae61ae6cb17fafcada1abd6c3c088c --- /dev/null +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -0,0 +1,214 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/defuser.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DefuserTest : public HloVerifiedTestBase { + protected: + // Returns the number of fusion instructions in the module. + int FusionCount() { + int count = 0; + for (HloComputation* computation : module().computations()) { + if (computation->IsFusionComputation()) { + count++; + } + } + return count; + } + + Defuser defuser_; + const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2}); +}; + +TEST_F(DefuserTest, NoFusionInstruction) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + + module().AddEntryComputation(builder.Build()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie()); +} + +TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(1, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Parameter(), op::Parameter())); +} + +TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion())); + + EXPECT_EQ(1, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Negate(op::Add(op::Parameter(), op::Parameter()))); +} + +TEST_F(DefuserTest, NonTrivialFusionInstruction) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto param3 = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3)); + auto div = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction( + {add2, constant, div, mul, sub, negate, add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(1, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Constant(), op::Divide())); +} + +TEST_F(DefuserTest, MultipleFusionInstructions) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto param3 = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3)); + auto div = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add2, constant, div, mul}, + HloInstruction::FusionKind::kLoop); + computation->CreateFusionInstruction({sub, negate, add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(2, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Constant(), op::Divide())); +} + +TEST_F(DefuserTest, NestedFusionInstructions) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + + auto computation = module().AddEntryComputation(builder.Build()); + auto outer_fusion = computation->CreateFusionInstruction( + {negate, add}, HloInstruction::FusionKind::kLoop); + HloInstruction* fused_negate = outer_fusion->fused_expression_root(); + ASSERT_EQ(fused_negate->opcode(), HloOpcode::kNegate); + outer_fusion->fused_instructions_computation()->CreateFusionInstruction( + {fused_negate}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(2, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 2c16a1b9033f45742f80b91eb1695315bd13ed80..adaff90913d391f2417dcba105a147a0635d88c4 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -85,6 +85,10 @@ class DfsHloVisitor { virtual Status HandleCopy(HloInstruction* copy) { return HandleElementwiseUnary(copy); } + virtual Status HandleComplex(HloInstruction* complex, HloInstruction* real, + HloInstruction* imag) { + return HandleElementwiseBinary(complex); + } virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { return HandleElementwiseBinary(multiply); @@ -122,6 +126,10 @@ class DfsHloVisitor { virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { return HandleElementwiseUnary(abs); } + virtual Status HandleAtan2(HloInstruction* atan2, HloInstruction* y, + HloInstruction* x) { + return HandleElementwiseBinary(atan2); + } virtual Status HandleRound(HloInstruction* round) { return HandleElementwiseUnary(round); } @@ -152,22 +160,42 @@ class DfsHloVisitor { virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { return HandleElementwiseUnary(tanh); } + virtual Status HandleReal(HloInstruction* real, HloInstruction* operand) { + return HandleElementwiseUnary(real); + } + virtual Status HandleImag(HloInstruction* imag, HloInstruction* operand) { + return HandleElementwiseUnary(imag); + } virtual Status HandleIsFinite(HloInstruction* is_finite, HloInstruction* operand) { return HandleElementwiseUnary(is_finite); } - virtual Status HandleLogicalAnd(HloInstruction* logical_and, - HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and); + virtual Status HandleAnd(HloInstruction* and_, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(and_); } - virtual Status HandleLogicalNot(HloInstruction* logical_not, - HloInstruction* operand) { - return HandleElementwiseUnary(logical_not); + virtual Status HandleNot(HloInstruction* not_, HloInstruction* operand) { + return HandleElementwiseUnary(not_); } - virtual Status HandleLogicalOr(HloInstruction* logical_or, + virtual Status HandleOr(HloInstruction* or_, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(or_); + } + virtual Status HandleShiftLeft(HloInstruction* shift_left, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or); + return HandleElementwiseBinary(shift_left); + } + virtual Status HandleShiftRightArithmetic( + HloInstruction* shift_right_arithmetic, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_arithmetic); } + virtual Status HandleShiftRightLogical(HloInstruction* shift_right_logical, + HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_logical); + } + virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { return HandleElementwiseUnary(reduce_precision); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 7117ecb08b2c1f83f155ca3d25d2831b5e411bc6..fd4c332cba94513ec5b4cd88a842189e716f35d5 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -54,10 +54,12 @@ StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == HloOpcode::kCopy) { return operand_value; + } else if (operand_value->getType()->isIntegerTy()) { + return EmitIntegerUnaryOp(op, operand_value); + } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { + return EmitComplexUnaryOp(op, operand_value); } else { - return operand_value->getType()->isIntegerTy() - ? EmitIntegerUnaryOp(op, operand_value) - : EmitFloatUnaryOp(op, operand_value); + return EmitFloatUnaryOp(op, operand_value); } } @@ -73,20 +75,35 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::IsIntegralType(to_type)) { return ir_builder_->CreateIntCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_), + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (primitive_util::IsSignedIntegralType(from_type)) { return ir_builder_->CreateSIToFP( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return ir_builder_->CreateUIToFP( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + } + if (primitive_util::IsComplexType(to_type)) { + auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( + primitive_util::ComplexComponentType(to_type), module_); + if (primitive_util::IsSignedIntegralType(from_type)) { + return ComposeComplex( + op, + ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), + nullptr); + } + if (primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED) { + return ComposeComplex( + op, + ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), + nullptr); } } return Unimplemented("conversion from primitive type %s to %s", @@ -97,8 +114,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); if (is_signed) { - auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), - ir_builder_); + auto type = + llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero); return ir_builder_->CreateSelect(cmp, operand_value, @@ -110,8 +127,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kSign: { bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); - auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), - ir_builder_); + auto type = + llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero); if (is_signed) { @@ -126,14 +143,21 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } case HloOpcode::kNegate: return ir_builder_->CreateNeg(operand_value); - case HloOpcode::kLogicalNot: - // It is not sufficient to just call CreateNot() here because a PRED is - // represented as an i8 and the truth value is stored only in the bottom - // bit. - return ir_builder_->CreateZExt( - ir_builder_->CreateNot(ir_builder_->CreateTrunc( - operand_value, ir_builder_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + case HloOpcode::kNot: { + auto type = op->shape().element_type(); + if (type == PRED) { + // It is not sufficient to just call CreateNot() here because a PRED + // is represented as an i8 and the truth value is stored only in the + // bottom bit. + return ir_builder_->CreateZExt( + ir_builder_->CreateNot(ir_builder_->CreateTrunc( + operand_value, ir_builder_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + } else if (primitive_util::IsIntegralType(type)) { + return ir_builder_->CreateNot(operand_value); + } + return Unimplemented("unary op Not is not defined for type '%d'", type); + } default: return Unimplemented("unary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -150,20 +174,30 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( if (from_type == to_type) { return operand_value; } + if (primitive_util::IsComplexType(to_type)) { + PrimitiveType to_component_type = + primitive_util::ComplexComponentType(to_type); + if (from_type == to_component_type) { + return ComposeComplex(op, operand_value, nullptr); + } + return ComposeComplex( + op, + ir_builder_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), + nullptr); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { return ir_builder_->CreateFPToSI( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { return ir_builder_->CreateFPToUI( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", PrimitiveType_Name(from_type).c_str(), @@ -223,7 +257,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); @@ -233,20 +267,164 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr ElementalIrEmitter::EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + switch (op->opcode()) { + // TODO(b/65209142): Angle/Log require atan2. + // case HloOpcode::kAngle: + // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + case HloOpcode::kConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(from_type)); + PrimitiveType to_type = op->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(to_type)); + if (from_type == to_type) { + return operand_value; + } + PrimitiveType to_component_type = + primitive_util::ComplexComponentType(to_type); + auto to_ir_component_type = + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); + return ComposeComplex( + op, + ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type), + ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type)); + } + case HloOpcode::kExp: { + // e^(a+bi) = e^a*(cos(b)+sin(b)i) + auto exp_a = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::exp, {real(operand_value)}, + {real(operand_value)->getType()}, ir_builder_); + auto cos_b = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::cos, {imag(operand_value)}, + {imag(operand_value)->getType()}, ir_builder_); + auto sin_b = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::sin, {imag(operand_value)}, + {imag(operand_value)->getType()}, ir_builder_); + return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); + } + case HloOpcode::kCos: { + // cos(z) = .5(e^(iz) + e^(-iz)) + // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai)) + // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have + // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i)) + // cos(-x) = cos(x) and sin(-x) = -sin(x), so + // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i)) + // = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) + auto a = real(operand_value); + auto b = imag(operand_value); + auto type = a->getType(); + auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, + {type}, ir_builder_); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, + {type}, ir_builder_); + auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, + {type}, ir_builder_); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); + } + case HloOpcode::kSin: { + // sin(z) = .5i(e^(-iz) - e^(iz)) + // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi))) + // = .5i(e^(b-ai) - e^(-b+ai)) + // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have + // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i)) + // = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a))) + // cos(-x) = cos(x) and sin(-x) = -sin(x), so + // = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a))) + // = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) + auto a = real(operand_value); + auto b = imag(operand_value); + auto type = a->getType(); + auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, + {type}, ir_builder_); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, + {type}, ir_builder_); + auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, + {type}, ir_builder_); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); + } + case HloOpcode::kAbs: { + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(operand_value), real(operand_value)), + ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, + {sum_sq->getType()}, ir_builder_); + } + case HloOpcode::kSign: { // Sign(c) = c / |c| + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(operand_value), real(operand_value)), + ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + auto cplx_abs = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); + auto type = cplx_abs->getType(); + auto zero = llvm::ConstantFP::get(type, 0.0); + auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); + return ir_builder_->CreateSelect( + oeq, ComposeComplex(op, zero, zero), + ComposeComplex( + op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs), + ir_builder_->CreateFDiv(imag(operand_value), cplx_abs))); + } + case HloOpcode::kNegate: + return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)), + ir_builder_->CreateFNeg(imag(operand_value))); + case HloOpcode::kReal: + return real(operand_value); + case HloOpcode::kImag: + return imag(operand_value); + default: + return Unimplemented("unary complex op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { - return lhs_value->getType()->isIntegerTy() - ? EmitIntegerBinaryOp(op, lhs_value, rhs_value, - primitive_util::IsSignedIntegralType( - op->operand(0)->shape().element_type())) - : EmitFloatBinaryOp(op, lhs_value, rhs_value); + PrimitiveType operand_type = op->operand(0)->shape().element_type(); + if (lhs_value->getType()->isIntegerTy()) { + return EmitIntegerBinaryOp( + op, lhs_value, rhs_value, + primitive_util::IsSignedIntegralType(operand_type)); + } else if (primitive_util::IsComplexType(operand_type)) { + return EmitComplexBinaryOp(op, lhs_value, rhs_value); + } else { + return EmitFloatBinaryOp(op, lhs_value, rhs_value); + } } StatusOr ElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { switch (op->opcode()) { + // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support + case HloOpcode::kComplex: + return ComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: return ir_builder_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: @@ -298,6 +476,88 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } } +StatusOr ElementalIrEmitter::EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + switch (op->opcode()) { + case HloOpcode::kAdd: + return ComposeComplex( + op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value))); + case HloOpcode::kSubtract: + return ComposeComplex( + op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value))); + case HloOpcode::kMultiply: + return ComposeComplex( + op, + ir_builder_->CreateFSub( + ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))), + ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)))); + case HloOpcode::kDivide: { + // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) + // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) + auto rhs_sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value))); + auto type = rhs_sum_sq->getType(); + auto zero = llvm::ConstantFP::get(type, 0.0); + auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); + return ir_builder_->CreateSelect( + oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero), + ComposeComplex( + op, + ir_builder_->CreateFDiv( + ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), + imag(rhs_value))), + rhs_sum_sq), + ir_builder_->CreateFDiv( + ir_builder_->CreateFSub( + ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(real(lhs_value), + imag(rhs_value))), + rhs_sum_sq))); + } + // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered + // comparisons always return false when one of the operands is NaN, whereas + // unordered comparisons return true. + // + // We use ordered comparisons for everything except kNe, where we use an + // unordered comparison. This makes x != y equivalent to !(x == y), and + // matches C++'s semantics. + case HloOpcode::kEq: + return ir_builder_->CreateAnd( + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value), + real(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value), + imag(rhs_value), ir_builder_)); + case HloOpcode::kNe: + return ir_builder_->CreateOr( + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value), + real(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value), + imag(rhs_value), ir_builder_)); + + // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic + // case HloOpcode::kPower: + // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2) + default: + return Unimplemented("binary complex op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value) const { return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_); @@ -389,7 +649,7 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, StatusOr ElementalIrEmitter::EmitErfcInv( PrimitiveType prim_type, llvm::Value* value) const { // Compute erfcinv(value) by calculating erfinv(1.0 - value). - auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_); + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } @@ -557,10 +817,16 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, lhs_value, rhs_value), lhs_value, rhs_value); - case HloOpcode::kLogicalAnd: + case HloOpcode::kAnd: return ir_builder_->CreateAnd(lhs_value, rhs_value); - case HloOpcode::kLogicalOr: + case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kShiftLeft: + return ir_builder_->CreateShl(lhs_value, rhs_value); + case HloOpcode::kShiftRightArithmetic: + return ir_builder_->CreateAShr(lhs_value, rhs_value); + case HloOpcode::kShiftRightLogical: + return ir_builder_->CreateLShr(lhs_value, rhs_value); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -606,7 +872,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( const { PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type(); llvm::Type* param_ir_type = - llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_); + llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_); // Same values as PCG library // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h @@ -770,7 +1036,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( return ir_builder_->CreateZExt( ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_)); + module_)); } default: return InvalidArgument( @@ -793,13 +1059,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, @@ -808,6 +1076,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kGe: @@ -821,8 +1091,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { const HloInstruction* lhs = hlo->operand(0); @@ -879,17 +1152,31 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; - llvm::PHINode* output = ir_builder_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_), - hlo->operands().size()); llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. + CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + init_block->getTerminator() == nullptr); + + llvm::BasicBlock* exit_block; + if (ir_builder_->GetInsertPoint() == init_block->end()) { + exit_block = llvm_ir::CreateBasicBlock( + /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + } else { + exit_block = init_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); + init_block->getTerminator()->eraseFromParent(); + } + + llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); + llvm::PHINode* output = + ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType( + hlo->shape().element_type(), module_), + hlo->operands().size()); auto prior_insert_point = ir_builder_->GetInsertPoint(); - llvm::BasicBlock* exit_block = - init_block->splitBasicBlock(output, "concat_merge"); ir_builder_->SetInsertPoint(init_block); - init_block->getTerminator()->eraseFromParent(); for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { @@ -1045,7 +1332,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( // else -> return data from 'index'. llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_), + module_), "ret_value_addr", ir_builder_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( slice_intersection, "slice_intersection", ir_builder_); @@ -1134,7 +1421,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( // } llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_), + module_), "pad_result_addr", ir_builder_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); @@ -1176,7 +1463,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ir_builder_); PrimitiveType primitive_type = hlo->shape().element_type(); llvm::Type* primitive_type_llvm = - llvm_ir::PrimitiveTypeToIrType(primitive_type, ir_builder_); + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( primitive_type_llvm, "dot_acc", ir_builder_); ir_builder_->CreateStore( @@ -1209,7 +1496,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; - if (primitive_util::IsFloatingPointType(primitive_type)) { + if (primitive_util::IsComplexType(primitive_type)) { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))); + next_accumulator = ir_builder_->CreateInsertValue( + current_accumulator, + ir_builder_->CreateFAdd(real(current_accumulator), product_real), + {0}); + next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator, + ir_builder_->CreateFAdd(imag(current_accumulator), product_imag), + {1}); + } else if (primitive_util::IsFloatingPointType(primitive_type)) { next_accumulator = ir_builder_->CreateFAdd( current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); @@ -1231,4 +1539,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } } +llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op, + llvm::Value* real, + llvm::Value* imag) const { + auto cplx_type = + llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); + auto complex = ir_builder_->CreateInsertValue( + llvm::ConstantAggregateZero::get(cplx_type), real, {0}); + if (imag != nullptr) { + complex = ir_builder_->CreateInsertValue(complex, imag, {1}); + } + return complex; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 35dfa88e9b02e3ec7686dc7fdded8cf4e88201fb..9d32436e38fa2fb3e27d09f01b860cd2edf2c8ac 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -55,6 +55,7 @@ class ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator) const; llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + llvm::Module* module() const { return module_; } protected: virtual StatusOr EmitIntegerUnaryOp( @@ -63,6 +64,9 @@ class ElementalIrEmitter { virtual StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -72,6 +76,10 @@ class ElementalIrEmitter { const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const; + virtual StatusOr EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const; + virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value) const; @@ -109,6 +117,11 @@ class ElementalIrEmitter { // compiled executable outside of the HLO code itself. const HloModuleConfig& hlo_module_config_; + protected: + // Composes a complex struct. imag may be nullptr for simple cast operations. + llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real, + llvm::Value* imag) const; + private: // Returns a ElementGenerator for a RNG HloInstruction. llvm_ir::ElementGenerator MakeRngElementGenerator( diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 62b8fa6a2b77e21ae3aa257935f5a22e3e8a130b..9c96d9eb30b5f9e51b7f5d82391c6b9f366898d6 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" @@ -82,7 +84,11 @@ Status Executable::DumpSessionModule() { } filename = SanitizeFileName(std::move(filename)); string file_path = tensorflow::io::JoinPath(directory_path, filename); - return tensorflow::WriteBinaryProto(env, file_path, session_module); + string result; + TF_RET_CHECK( + tensorflow::SerializeToStringDeterministic(session_module, &result)); + return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, + result); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 82c32407d3d10dffc89c448ee402ea3a789a0c39..de84e06cebab72d272bd888f280f5e5b221b97d1 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -104,7 +104,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:core", ], @@ -147,6 +147,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 7cf5613ce571cadc5ad45ade17341376f3d6ae39..5aaf072f9d2c95e2fff70a1c5337432a12a1aa48 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -72,8 +72,10 @@ MatchBackwardFilter(HloInstruction* conv) { // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = conv->convolution_dimension_numbers(); - auto batch_dim = conv_dnums.batch_dimension(); - auto feature_dim = conv_dnums.feature_dimension(); + auto input_batch_dim = conv_dnums.input_batch_dimension(); + auto input_feature_dim = conv_dnums.input_feature_dimension(); + auto output_batch_dim = conv_dnums.output_batch_dimension(); + auto output_feature_dim = conv_dnums.output_feature_dimension(); auto spatial_dims = conv_dnums.spatial_dimensions(); for (const WindowDimension& window_dim : conv->window().dimensions()) { @@ -183,8 +185,10 @@ MatchBackwardFilter(HloInstruction* conv) { // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; - backward_conv_dnums.set_batch_dimension(feature_dim); - backward_conv_dnums.set_feature_dimension(batch_dim); + backward_conv_dnums.set_input_batch_dimension(input_feature_dim); + backward_conv_dnums.set_input_feature_dimension(input_batch_dim); + backward_conv_dnums.set_output_batch_dimension(output_feature_dim); + backward_conv_dnums.set_output_feature_dimension(output_batch_dim); for (int i = 0; i < spatial_dims.size(); ++i) { backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]); } @@ -198,9 +202,9 @@ MatchBackwardFilter(HloInstruction* conv) { // the dimension numbering of the weight gradients. This transposition maps // dimension i to PositionInContainer(transpose->dimensions(), i). backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), batch_dim)); + PositionInContainer(transpose->dimensions(), output_batch_dim)); backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), feature_dim)); + PositionInContainer(transpose->dimensions(), output_feature_dim)); for (int i = 0; i < spatial_dims.size(); ++i) { backward_conv_dnums.add_kernel_spatial_dimensions( PositionInContainer(transpose->dimensions(), spatial_dims[i])); @@ -275,7 +279,7 @@ MatchBackwardInput(HloInstruction* conv) { Window new_window = old_window; for (size_t i = 0; i < spatial_dims.size(); ++i) { // Restore backward convolution's padding config from the matched pattern. - // See the comment in tensorflow/core/kernels/conv_grad_ops.cc + // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc // for how we convert backward input convolution to a variant of forward // convolution. // diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 6699c8f3c4acd76ed58cccf314ca0ae1502d51d7..19b122ba0603b4ec08d73e05da4c2ae11a760553 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -45,8 +45,10 @@ class ConvolutionFoldingTest : public HloTestBase { // dimension in gradients as the input feature dimension in the filter. // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. - tf_default_dnums_for_backward_filter_.set_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); @@ -55,8 +57,10 @@ class ConvolutionFoldingTest : public HloTestBase { tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); - tf_default_dnums_for_backward_input_.set_batch_dimension(0); - tf_default_dnums_for_backward_input_.set_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.add_spatial_dimensions(1); tf_default_dnums_for_backward_input_.add_spatial_dimensions(2); tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); @@ -250,8 +254,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { conv_window.mutable_dimensions(i)->set_padding_high(3); } ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_batch_dimension(0); - conv_dnums.set_feature_dimension(1); + conv_dnums.set_input_batch_dimension(0); + conv_dnums.set_output_batch_dimension(0); + conv_dnums.set_input_feature_dimension(1); + conv_dnums.set_output_feature_dimension(1); conv_dnums.add_spatial_dimensions(2); conv_dnums.add_spatial_dimensions(3); conv_dnums.set_kernel_input_feature_dimension(0); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 7dd242425c9841e63259e8d54b9fedc203a65af5..536b96dcf620e908e25a775bc2efb57ba5f5edd6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -141,8 +141,8 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( BatchDescriptor input_descriptor(effective_num_dimensions); input_descriptor.set_layout(DataLayout::kBatchDepthYX) .set_feature_map_count( - input_shape_.dimensions(dim_nums_.feature_dimension())) - .set_count(input_shape_.dimensions(dim_nums_.batch_dimension())); + input_shape_.dimensions(dim_nums_.input_feature_dimension())) + .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { // Note that the dimensions are reversed. The same holds below. input_descriptor.set_spatial_dim( @@ -176,8 +176,8 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( BatchDescriptor output_descriptor(effective_num_dimensions); output_descriptor.set_layout(DataLayout::kBatchDepthYX) .set_feature_map_count( - output_shape_.dimensions(dim_nums_.feature_dimension())) - .set_count(output_shape_.dimensions(dim_nums_.batch_dimension())); + output_shape_.dimensions(dim_nums_.output_feature_dimension())) + .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { output_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index 87858e94090d1f7506ee09b9015b4417aee55707..f4498663b1c039b3175376baf8f27c4ecec678ec 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -20,15 +20,16 @@ limitations under the License. namespace xla { namespace gpu { -CopyThunk::CopyThunk(const void* source_address, - const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction) +HostToDeviceCopyThunk::HostToDeviceCopyThunk( + const void* source_address, + const BufferAllocation::Slice& destination_buffer, uint64 mem_size, + const HloInstruction* hlo_instruction) : Thunk(Kind::kCopy, hlo_instruction), source_address_(source_address), destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status CopyThunk::ExecuteOnStream( +tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) { perftools::gputools::DeviceMemoryBase destination_data = @@ -37,5 +38,24 @@ tensorflow::Status CopyThunk::ExecuteOnStream( return tensorflow::Status::OK(); } +DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, uint64 mem_size, + const HloInstruction* hlo_instruction) + : Thunk(Kind::kCopy, hlo_instruction), + source_buffer_(source_buffer), + destination_buffer_(destination_buffer), + mem_size_(mem_size) {} + +tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + perftools::gputools::DeviceMemoryBase destination_data = + buffer_allocations.GetDeviceAddress(destination_buffer_); + perftools::gputools::DeviceMemoryBase source_data = + buffer_allocations.GetDeviceAddress(source_buffer_); + stream->ThenMemcpy(&destination_data, source_data, mem_size_); + return tensorflow::Status::OK(); +} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 6b8c432715f27fc02b13fc242db5ee6db098c47e..e2783fd255239d31edc89701ea208f33ebb8d3fb 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -26,19 +26,18 @@ limitations under the License. namespace xla { namespace gpu { -// A thunk that copies data. For now, it copies data only from host to device. -// But it can be extended to perform device-to-host or intra-device copying. -class CopyThunk : public Thunk { +// A thunk that copies data from a host buffer to a device buffer. +class HostToDeviceCopyThunk : public Thunk { public: // Constructs a CopyThunk that copies host data from `source_address` to the // device buffer `destination_buffer`. `mem_size` is the size of the data in // bytes. - CopyThunk(const void* source_address, - const BufferAllocation::Slice& destination_buffer, uint64 mem_size, - const HloInstruction* hlo_instruction); + HostToDeviceCopyThunk(const void* source_address, + const BufferAllocation::Slice& destination_buffer, + uint64 mem_size, const HloInstruction* hlo_instruction); - CopyThunk(const CopyThunk&) = delete; - CopyThunk& operator=(const CopyThunk&) = delete; + HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; + HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; tensorflow::Status ExecuteOnStream( const BufferAllocations& buffer_allocations, @@ -50,6 +49,30 @@ class CopyThunk : public Thunk { const uint64 mem_size_; }; +// A thunk that copies data from a device buffer to another device buffer. +class DeviceToDeviceCopyThunk : public Thunk { + public: + // Constructs a CopyThunk that copies host data from `source_buffer` to the + // device buffer `destination_buffer`. `mem_size` is the size of the data in + // bytes. + DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + uint64 mem_size, + const HloInstruction* hlo_instruction); + + DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; + DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; + + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const BufferAllocation::Slice source_buffer_; + const BufferAllocation::Slice destination_buffer_; + const uint64 mem_size_; +}; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 8810a85ceeafd8b2d9ad8d7412266847abe5b75d..1b94499bc6ef6d587cdb1fafec48bc4e5b917c51 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -135,6 +135,10 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { + case HloOpcode::kAtan2: + return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value}, + {lhs_input_type, rhs_input_type}, + output_type); case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, {lhs_input_type, rhs_input_type}, @@ -226,6 +230,112 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + PrimitiveType component_type = + primitive_util::IsComplexType(input_type) + ? primitive_util::ComplexComponentType(input_type) + : input_type; + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + + switch (op->opcode()) { + case HloOpcode::kLog: { + // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + auto a = real(operand_value); + auto b = imag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN( + auto log_sum_sq, + EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a}, + {component_type, component_type}, + component_type)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq), + angle); + } + // TODO(b/65408531): Implement kPower on GPU, where atan2 is available. + // case HloOpcode::kPower: + // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di)) + case HloOpcode::kExp: { + // e^(a+bi) = e^a*(cos(b)+sin(b)i) + auto b = imag(operand_value); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)}, + {component_type}, component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, + component_type)); + return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); + } + case HloOpcode::kCos: { + // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) + auto a = real(operand_value); + auto llvm_ty = a->getType(); + TF_ASSIGN_OR_RETURN( + auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, + {component_type}, component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, + component_type)); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); + } + + case HloOpcode::kSin: { + // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) + auto a = real(operand_value); + auto llvm_ty = a->getType(); + TF_ASSIGN_OR_RETURN( + auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, + {component_type}, component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, + component_type)); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); + } + default: + return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); + } +} + llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, @@ -235,13 +345,12 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( std::vector ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( - llvm_ir::PrimitiveTypeToIrType(input_type, ir_builder_)); + llvm_ir::PrimitiveTypeToIrType(input_type, module_)); } llvm::FunctionType* callee_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(output_type, - ir_builder_), // The return type. - ir_input_types, // The parameter types. - false); // No variadic arguments. + llvm_ir::PrimitiveTypeToIrType(output_type, module_), // Return type. + ir_input_types, // Parameter types. + false); // No variadic arguments. // Declares the callee if it is not declared already. llvm::Function* callee = llvm::cast( @@ -315,7 +424,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accum_ptr", ir_builder_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, @@ -377,7 +486,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); llvm::Value* accum_ptr = ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), ir_builder())); + hlo->shape().element_type(), module_)); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))({})); ir_builder()->CreateStore(init_value, accum_ptr); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 6ddfc3710c56a4e129f050f862812a3d78d8dba0..3defa1b696d3addc012702e23102bb1fa140170d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -54,6 +54,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 0bcdf8a61de112e7f337653007b70f35c4924365..b5331fe4e2ba34443555e9bf46dfc188cbd6548a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -67,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -84,6 +85,8 @@ namespace gpu { namespace { +using tensorflow::strings::StrCat; + // Any address of a variable residing in global memory or returned by one of the // memory allocation routines from the driver or runtime API is always aligned // to at least 256 bytes. @@ -148,6 +151,7 @@ tensorflow::Status OptimizeHloModule( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); } @@ -223,7 +227,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting( } // Invokes the ptxas tool on the given PTX string, and dumps its output. -void DumpPtxasInfo(const string& ptx) { +void DumpPtxasInfo(const string& ptx, int cc_major, int cc_minor) { const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas"); // Do not log PTX stats if ptxas is not found at the given path. @@ -245,17 +249,22 @@ void DumpPtxasInfo(const string& ptx) { // Invoke ptxas and collect its output. tensorflow::SubProcess ptxas_info_dumper; - ptxas_info_dumper.SetProgram(ptxas_path, {ptxas_path, ptx_path, "-o", - "/dev/null", "-v", "-arch=sm_35"}); + ptxas_info_dumper.SetProgram(ptxas_path, + {ptxas_path, ptx_path, "-o", "/dev/null", "-v", + StrCat("-arch=sm_", cc_major, cc_minor)}); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); - CHECK(ptxas_info_dumper.Start()); + if (!ptxas_info_dumper.Start()) { + LOG(ERROR) << "Failed to launch ptxas."; + return; + } string stderr_output; int exit_status = ptxas_info_dumper.Communicate( /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); XLA_LOG_LINES(tensorflow::INFO, stderr_output); if (exit_status != 0) { - LOG(FATAL) << "Invalid PTX. See the error message above for reasons."; + LOG(ERROR) << "ptxas exited with non-zero error code " << exit_status + << "."; } } @@ -310,12 +319,12 @@ StatusOr> GpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, buffer_assignment->ToString()); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); - if (!dump_debug_json_to.empty()) { + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), @@ -387,7 +396,7 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "PTX:"; XLA_VLOG_LINES(2, *ptx); if (VLOG_IS_ON(2)) { - DumpPtxasInfo(*ptx); + DumpPtxasInfo(*ptx, cc_major, cc_minor); } auto thunk_schedule = MakeUnique( @@ -408,7 +417,7 @@ StatusOr> GpuCompiler::Compile( StatusOr>> GpuCompiler::Compile( std::vector> modules, - std::vector stream_execs) { + std::vector> stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on GPU."); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index b5ffeef44ff5993cf822263680baab8deadd1799..58e835e5ee3f77b7b5cb3579514b7501bed2a2a1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -46,7 +46,8 @@ class GpuCompiler : public LLVMCompiler { StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) override; + std::vector> + stream_execs) override; StatusOr>> CompileAheadOfTime(std::vector> module, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 373c1aa5f9582fdb5f03f17f8a90a5e640f7b54d..163a161353fdb90cee2968269d572b8414855551 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -67,7 +67,7 @@ void HloToIrBindings::EmitBasePointersForHlos( // Lookup allocation GetTupleElement operand. const BufferAllocation::Slice slice = buffer_assignment_ - ->GetUniqueTopLevelSlice(LatestNonGteAncestor(non_io_hlo)) + ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor()) .ConsumeValueOrDie(); // We are not in a nested context, so check non-thread-local allocation. CHECK(!slice.allocation()->is_thread_local()); @@ -102,7 +102,7 @@ void HloToIrBindings::EmitBasePointersForHlos( slice_result.ConsumeValueOrDie(); if (slice.allocation()->is_thread_local()) { llvm::Type* pointee_type = - llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); + llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type), index); } else { @@ -124,18 +124,18 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_); + EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_); } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, const ShapeIndex& shape_index, llvm::Value* ir_value) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType( - ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_); + ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); llvm::Type* dest_type = pointee_type->getPointerTo(); llvm::Value* typed_ir_value; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index d43e09e8a8c5cc2efcd8e1fbf9a7c0697e24d73c..a3120f15bcbfb0f2f0bfbd806e7a4ff05316d5dd 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -36,10 +36,12 @@ class HloToIrBindings { public: HloToIrBindings(const HloModule& module, const BufferAssignment* buffer_assignment, - llvm::IRBuilder<>* ir_builder, bool is_nested) + llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module, + bool is_nested) : buffer_assignment_(buffer_assignment), is_nested_(is_nested), ir_builder_(ir_builder), + module_(llvm_module), alias_analysis_(module, *buffer_assignment_, &ir_builder_->getContext()) {} @@ -93,6 +95,7 @@ class HloToIrBindings { const bool is_nested_; llvm::IRBuilder<>* ir_builder_; + llvm::Module* module_; // Stores the underlying llvm::IrArray for each HloInstruction. // For an instruction that generates multiple outputs, the root will be a diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 0b94594f1dc5cd040846eabaad01b4cd09520e12..9a4bfd0905bb62c02c70e7f2eea46872c07bca89 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -152,8 +152,10 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { conv_window_col->set_padding_high(1); ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_batch_dimension(0); - conv_dnums.set_feature_dimension(1); + conv_dnums.set_input_batch_dimension(0); + conv_dnums.set_output_batch_dimension(0); + conv_dnums.set_input_feature_dimension(1); + conv_dnums.set_output_feature_dimension(1); conv_dnums.add_spatial_dimensions(2); conv_dnums.add_spatial_dimensions(3); conv_dnums.set_kernel_output_feature_dimension(0); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6be26dde8f957040c73db6a7e52f050e44d44c06..8fb7a6adda9dc7c36eb9aabcbcdc9d77e6c22c4a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -214,12 +214,5 @@ llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } -const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo) { - while (hlo->opcode() == HloOpcode::kGetTupleElement) { - hlo = hlo->operand(0); - } - return hlo; -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 422972762ee3da793852429a71b4cee76e41e2bc..06c3205296e4546e39525ec093cc17e2fc375d0d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -53,10 +53,6 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); -// Resolves GetTupleElement instruction operands starting with 'hlo'. -// Returns the first ancestor instruction which is not a GetTupleElement. -const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a76d217cac271bcda950c1c325f67810dd513383..23765e05e8ac4f9f005cba166634cd48bb1e7c80 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -34,7 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -53,9 +53,10 @@ namespace gpu { IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested) : ir_emitter_context_(ir_emitter_context), - ir_builder_(ir_emitter_context->llvm_module()->getContext()), + module_(ir_emitter_context->llvm_module()), + ir_builder_(module_->getContext()), bindings_(ir_emitter_context->hlo_module(), - &ir_emitter_context->buffer_assignment(), &ir_builder_, + &ir_emitter_context->buffer_assignment(), &ir_builder_, module_, is_nested), hlo_module_config_(hlo_module_config) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( @@ -71,18 +72,17 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { }; } return EmitTargetElementLoop( - *hlo, GpuElementalIrEmitter(hlo_module_config_, - ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()) + *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, + GetNestedComputer()) .MakeElementGenerator(hlo, operand_to_generator)); } Status IrEmitter::HandleConstant(HloInstruction* constant, const Literal& literal) { llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); + llvm_ir::ConvertLiteralToIrConstant(literal, module_); llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *ir_emitter_context_->llvm_module(), initializer->getType(), + *module_, initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, /*Name=*/""); VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl @@ -115,7 +115,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, get_tuple_element->shape(), get_tuple_element->tuple_index(), // TODO(b/26344050): tighten the alignment here // based on the real element type. - /*alignment=*/1, GetBasePointer(*operand), &ir_builder_)); + /*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_)); return Status::OK(); } @@ -140,7 +140,7 @@ Status IrEmitter::HandleTuple( for (const HloInstruction* operand : operands) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_); + llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_); return Status::OK(); } @@ -329,7 +329,7 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, if (ShapeUtil::IsTuple(select->shape())) { llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred), GetBasePointer(*on_true), - GetBasePointer(*on_false), &ir_builder_); + GetBasePointer(*on_false), &ir_builder_, module_); return Status::OK(); } @@ -355,7 +355,26 @@ Status IrEmitter::HandleDot(HloInstruction* dot, lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); llvm::Value* rhs_value = rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); - llvm::Value* result = ir_builder_.CreateFMul(lhs_value, rhs_value); + llvm::Value* result; + if (ShapeUtil::ElementIsComplex(lhs_shape)) { + auto real = [&](llvm::Value* x) { + return ir_builder_.CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_.CreateExtractValue(x, {1}); + }; + llvm::Value* real_result = ir_builder_.CreateFSub( + ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value))); + llvm::Value* imag_result = ir_builder_.CreateFAdd( + ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)), + ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value))); + result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); + result = ir_builder_.CreateInsertValue(result, real_result, {0}); + result = ir_builder_.CreateInsertValue(result, imag_result, {1}); + } else { + result = ir_builder_.CreateFMul(lhs_value, rhs_value); + } target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_); return Status::OK(); } @@ -411,8 +430,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot, // Initialize the accumulator in the preheader to zero. new llvm::StoreInst( - llvm::ConstantFP::get(accum_type, 0.0), // The value stored. - accum_address, // The address. + llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0 + accum_address, // The address. reduction_loop->GetPreheaderBasicBlock() ->getTerminator()); // The instruction this store is inserted before. @@ -427,9 +446,27 @@ Status IrEmitter::HandleDot(HloInstruction* dot, lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_); - llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); llvm::Value* accum = ir_builder_.CreateLoad(accum_address); - llvm::Value* updated_accum = ir_builder_.CreateFAdd(accum, product); + llvm::Value* updated_accum; + if (ShapeUtil::ElementIsComplex(lhs_shape)) { +#define REAL(x) ir_builder_.CreateExtractValue(x, {0}) +#define IMAG(x) ir_builder_.CreateExtractValue(x, {1}) + llvm::Value* product_real = ir_builder_.CreateFSub( + ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)), + ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element))); + llvm::Value* product_imag = ir_builder_.CreateFAdd( + ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)), + ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element))); + updated_accum = ir_builder_.CreateInsertValue( + accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0}); + updated_accum = ir_builder_.CreateInsertValue( + updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1}); +#undef IMAG +#undef REAL + } else { + llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); + updated_accum = ir_builder_.CreateFAdd(accum, product); + } ir_builder_.CreateStore(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target @@ -494,7 +531,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( - reduce->shape().element_type(), &ir_builder_)); + reduce->shape().element_type(), module_)); ir_builder_.CreateStore( ir_builder_.CreateLoad(GetBasePointer(*init_value)), accumulator_addr); @@ -547,8 +584,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand)); } - GpuElementalIrEmitter elemental_emitter(hlo_module_config_, - ir_emitter_context_->llvm_module(), + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); @@ -591,9 +627,8 @@ Status IrEmitter::HandleRng(HloInstruction* random, // Emits a single-threaded loop because the loop body generated by the element // generator for Rng can't be parallelized (b/32333178). return llvm_ir::LoopEmitter( - GpuElementalIrEmitter(hlo_module_config_, - ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()) + GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, + GetNestedComputer()) .MakeElementGenerator(random, operand_to_generator), GetIrArray(*random), &ir_builder_) .EmitLoop(IrName(random)); @@ -634,7 +669,7 @@ StatusOr IrEmitter::ComputeNestedElement( tensorflow::gtl::ArraySlice parameter_elements) { llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( - computation.root_instruction()->shape().element_type(), &ir_builder_), + computation.root_instruction()->shape().element_type(), module_), "return_buffer", &ir_builder_); std::vector parameter_buffers; for (llvm::Value* parameter_element : parameter_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 5e3f3bfdf18bdd5b4f8d0e565d1bb2613cebc3a1..90f40639d583632c4981278ffbfbf53fcf7ba989 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -162,6 +162,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { } IrEmitterContext* ir_emitter_context_; + llvm::Module* module_; // The following fields track the IR emission state. According to LLVM memory // management rules, their memory is owned by the module. @@ -339,8 +340,12 @@ class IrEmitterUnnested : public IrEmitter { // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); - // Returns a CopyThunk that calls host-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildCopyThunk(const HloInstruction* inst); + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); + + // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildDeviceToDeviceCopyThunk( + const HloInstruction* inst); // Returns an InfeedThunk that performs device-to-device memcpy to implement // `inst`. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 57f010530cc93cf5f2ef60470ce416fe9333a94e..5da1a130d5654b86803396b07a6501c59a182c67 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -52,9 +52,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( io_hlos->push_back(param); const Shape& param_shape = param->shape(); argument_types.push_back( - llvm_ir::ShapeToIrType(param_shape, &ir_builder_)->getPointerTo()); - int64 param_size = llvm_ir::ByteSizeOf( - param_shape, ir_emitter_context_->llvm_module()->getDataLayout()); + llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo()); + int64 param_size = + llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout()); argument_dereferenceable_bytes.push_back(param_size); } { @@ -62,7 +62,7 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( io_hlos->push_back(root); const Shape& root_shape = root->shape(); argument_types.push_back( - llvm_ir::ShapeToIrType(root_shape, &ir_builder_)->getPointerTo()); + llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo()); int64 root_size = llvm_ir::ByteSizeOf( root_shape, ir_emitter_context_->llvm_module()->getDataLayout()); argument_dereferenceable_bytes.push_back(root_size); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 4e6b109b80879108a63a3820dbd1c82b64255f36..1c7e18304df94db836acff953ab7871df403835d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -145,7 +146,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } namespace { -bool ImplementedAsMemcpy(const HloInstruction& hlo) { +bool ImplementedAsHostToDeviceMemcpy(const HloInstruction& hlo) { // `hlo` needs to satisfy three conditions to be implemented as a // host-to-device cuMemcpy. // @@ -156,6 +157,20 @@ bool ImplementedAsMemcpy(const HloInstruction& hlo) { hlo.operand(0)->opcode() == HloOpcode::kConstant && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()); } + +bool ImplementedAsDeviceToDeviceMemcpy( + const BufferAssignment& buffer_assignment, const HloInstruction& hlo) { + // `hlo` needs to satisfy three conditions to be implemented as a + // device-to-device cuMemcpy. + // + // 1. `hlo` is a kCopy instruction. + // 2. `hlo` and its operand have the same shape (thus the same layout too). + // 3. The operand to `hlo` has a buffer assignment (constants do not, for + // instance) which means the source buffer also resides on the device. + return hlo.opcode() == HloOpcode::kCopy && + ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && + buffer_assignment.HasTopLevelAllocation(hlo.operand(0)); +} } // namespace llvm::Function* IrEmitterUnnested::BuildKernelPrototype( @@ -254,46 +269,6 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution, rhs_instruction, window); } -namespace { - -// Returns the first non-GetTupleElement ancestor instruction of 'hlo'. -// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the -// (possibly nested) tuple indices used on the path from ancestor to 'hlo'. -const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo, - ShapeIndex* index) { - if (hlo->opcode() == HloOpcode::kGetTupleElement) { - const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index); - index->push_back(hlo->tuple_index()); - return operand; - } - return hlo; -} - -// Checks if we can emit code for DynamicUpdateSlice to update data in-place. -// Returns true if operand 0 of DynamicUpdateSlice and its output buffer -// share the same buffer allocation. -// Returns false otherwise. -bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, - HloInstruction* fusion) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - HloInstruction* fused_root = fusion->fused_expression_root(); - if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice) { - return false; - } - // Walk DynamicUpdateSlice operand(0) to fused parameter and get its - // associated operand. See if it shares an allocation with this operand. - ShapeIndex index; - auto* fusion_operand = - LatestNonGteAncestorAndIndex(fused_root->operand(0), &index); - if (fusion_operand->opcode() != HloOpcode::kParameter) { - return false; - } - auto* operand = fusion->operand(fusion_operand->parameter_number()); - return assignment.SharesSliceAtIndex(fusion, {}, operand, index); -} - -} // namespace - Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); // HandleFusion specializes reduction from a multi-dimensional array to a 1D @@ -364,95 +339,40 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { LOG(FATAL) << "Bad opcode for input fusion: " << fusion->fused_expression_root()->opcode(); } - } else if (HloInstruction::FusionKind::kLoop == fusion->fusion_kind() && - root->opcode() == HloOpcode::kDynamicUpdateSlice && - CanUpdateDynamicSliceInPlace( - ir_emitter_context_->buffer_assignment(), fusion)) { - // Loop fusion instruction with DynamicUpdateSlice as fused root. - // DynamicUpdateSlice's operand(0) and 'fusion' output share the same - // BufferAllocation::Slice, so it is safe to emit code to update the slice - // 'in-place'. This avoids copying data outside of the slice update region. + } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace( + fusion, ir_emitter_context_->buffer_assignment())) { + // Fusion node with dynamic-update-slice as the root where the op's input + // (i.e. array to update) shares the same slice as its output. In this case + // we have a special algorithm that modifies the output in place without + // touching the un-updated elements. // Set up kernel thunk and fused ir emitter. thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); - std::vector parameter_arrays; + std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + operand_arrays.push_back(GetIrArray(*operand)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &ir_builder_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - - // Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0. - auto* fusion_operand = LatestNonGteAncestor(root->operand(0)); - CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode()); - - // Operand(0) the input array which shares an allocation with the output. - const auto* input = root->operand(0); - llvm::Value* input_base_ptr = fused_emitter.GetIrValueForGTE(input); - // Operand(1) 'update' is slice with which to update input at operand(0). - const auto* update = root->operand(1); - Shape update_shape = update->shape(); - TF_RETURN_IF_ERROR( - LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); - // Operand(2) the dynamic slice indices at which to write 'update'. - const auto* start_indices = root->operand(2); - - // Create element generators for 'update' and 'start_indices'. - llvm_ir::ElementGenerator element_generator = - fused_emitter.GetGenerator(update); - llvm_ir::ElementGenerator start_generator = - fused_emitter.GetGenerator(start_indices); - - // Create loop body emitter which emits code to do the following: - // *) Read dynamic slice start indices into 'start_index'. - // *) Map requested 'index' and slice 'start_index' to input/output shape - // as 'output_index'. - // *) Reads value from 'update' element generator. - // *) Writes value to input/output array at 'output_index'. - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& index) -> Status { - // Emit IR to read dynamic start indices from hlo->operand(2). - const int64 rank = ShapeUtil::Rank(input->shape()); - llvm_ir::IrArray::Index start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)}); - TF_ASSIGN_OR_RETURN(start_index[i], start_generator(dim_index)); - } - // Calculate 'output_index' at which to write value from update. - llvm_ir::IrArray::Index output_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // output_index = (start_index + index) % dim_size - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input->shape().dimensions(i)); - llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast( - start_index[i], index[i]->getType()); - output_index[i] = ir_builder_.CreateURem( - ir_builder_.CreateAdd(start_index0, index[i]), dim_size); - } + // Shape of the dynamic-update-slice's "update" operand. + Shape update_shape = root->operand(1)->shape(); - // Read value from 'update'. - TF_ASSIGN_OR_RETURN(llvm::Value * input_value, element_generator(index)); - // Write value to output array. - llvm_ir::IrArray(input_base_ptr, input->shape()) - .EmitWriteArrayElement(output_index, input_value, &ir_builder_); - return Status::OK(); - }; + // Array to write into. Because this is an in-place operation, this is the + // same as operand 0's array. + llvm_ir::IrArray output_array = GetIrArray(*fusion); - // Create loop which iterates over 'update' shape. LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); UpdateLaunchDimensions(launch_dimensions, static_cast(LastThunk()), ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, update_shape, - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(fusion)); + + return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( + fusion, operand_arrays, output_array, &elemental_emitter, + launch_dimensions, &ir_builder_); } if (ImplementedAsGemm(*fusion)) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); @@ -758,8 +678,13 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsMemcpy(*copy)) { - thunk_sequence_->emplace_back(BuildCopyThunk(copy)); + if (ImplementedAsHostToDeviceMemcpy(*copy)) { + thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); + return Status::OK(); + } + if (ImplementedAsDeviceToDeviceMemcpy( + ir_emitter_context_->buffer_assignment(), *copy)) { + thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } bool is_transpose_021; @@ -832,8 +757,8 @@ Status IrEmitterUnnested::EmitColumnReduction( auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { // Emit the loop body that reduces one tile. - llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), &ir_builder_); + llvm::Type* element_ir_type = + llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); { @@ -1048,7 +973,7 @@ Status IrEmitterUnnested::EmitRowReduction( [=](const llvm_ir::IrArray::Index& tile_index) -> Status { // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), &ir_builder_); + input_shape.element_type(), ir_emitter_context_->llvm_module()); llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); { @@ -1435,7 +1360,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // boolean flag if the value is initialized. The initialized_flag is set // false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, + ir_emitter_context_->llvm_module()), "selected_value_address", &ir_builder_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -1515,7 +1441,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(PRED, + ir_emitter_context_->llvm_module()), "select_return_buffer", &ir_builder_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), @@ -1525,8 +1452,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. llvm::Value* cond = ir_builder_.CreateICmpNE( - result, llvm::ConstantInt::get( - llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0), + result, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( + PRED, ir_emitter_context_->llvm_module()), + 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); @@ -1625,7 +1554,7 @@ llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( // with their operand buffer in 'io_hlos' and 'non_io_hlos' below. std::vector non_io_hlos; for (const HloInstruction* operand : hlo.operands()) { - const HloInstruction* to_lookup = LatestNonGteAncestor(operand); + const HloInstruction* to_lookup = operand->LatestNonGteAncestor(); if (buffer_assignment.HasTopLevelAllocation(to_lookup) && buffer_assignment.GetUniqueTopLevelSlice(to_lookup) .ConsumeValueOrDie() @@ -1665,7 +1594,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( std::vector io_buffers; io_buffers.reserve(io_hlos.size()); for (const HloInstruction* io_hlo : io_hlos) { - io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo))); + io_buffers.push_back(GetAllocationSlice(*io_hlo->LatestNonGteAncestor())); } // Create a KernelThunk that launches the kernel that implements "inst". @@ -1673,11 +1602,11 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm_ir::AsString(kernel->getName()), inst); } -std::unique_ptr IrEmitterUnnested::BuildCopyThunk( +std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique( + return MakeUnique( /*source_address=*/operand->literal().InternalData(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -1686,6 +1615,18 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( inst); } +std::unique_ptr IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( + const HloInstruction* inst) { + const HloInstruction* operand = inst->operand(0); + return MakeUnique( + /*source_address=*/GetAllocationSlice(*operand), + /*destination_buffer=*/GetAllocationSlice(*inst), + /*mem_size=*/ + llvm_ir::ByteSizeOf(operand->shape(), + ir_emitter_context_->llvm_module()->getDataLayout()), + inst); +} + std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); @@ -1940,7 +1881,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_, + module_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc index b0480e2f475c7b295b88d36a91ae08a90b818085..0bbd63fb7bfc657cb7bb1de673253c198f5bd25f 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc @@ -84,8 +84,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( --i) { input_layout.push_back(dimension_numbers.spatial_dimensions(i)); } - input_layout.push_back(dimension_numbers.feature_dimension()); - input_layout.push_back(dimension_numbers.batch_dimension()); + input_layout.push_back(dimension_numbers.input_feature_dimension()); + input_layout.push_back(dimension_numbers.input_batch_dimension()); Shape input_shape(input->shape()); *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); @@ -106,8 +106,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( --i) { output_layout.push_back(dimension_numbers.spatial_dimensions(i)); } - output_layout.push_back(dimension_numbers.feature_dimension()); - output_layout.push_back(dimension_numbers.batch_dimension()); + output_layout.push_back(dimension_numbers.output_feature_dimension()); + output_layout.push_back(dimension_numbers.output_batch_dimension()); Shape output_shape(output->shape()); *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index af853385d634b06d31cef94216fb4059dfcadc3d..79493c4112804f8454d200f3f83aa85d718f0d0a 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -39,6 +39,8 @@ message HloInstructionProto { string name = 1; string opcode = 2; xla.Shape shape = 3; + + // TODO(b/67782397): Replace instruction names with HloInstruction ids. repeated string operand_names = 4; repeated string control_predecessor_names = 5; repeated string called_computation_names = 6; @@ -58,6 +60,64 @@ message HloInstructionProto { // Index for kGetTupleElement. int64 tuple_index = 13; + + // Dimensions present for some operations that require reshaping or + // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. + repeated int64 dimensions = 14; + + // Describes the window in a windowed operation such as convolution. + xla.Window window = 15; + + // Describes the dimension numbers used for a convolution. + xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + + // Describes the [begin, end) index range and stride for slices. + message SliceDimensions { + int64 start = 1; + int64 limit = 2; + int64 stride = 3; + } + repeated SliceDimensions slice_dimensions = 17; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits = 18; + int32 mantissa_bits = 19; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + repeated int64 dynamic_slice_sizes = 20; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. Only set for pad instructions. + xla.PaddingConfig padding_config = 21; + + // Outfeed configuration information, only present for kOutfeed. + bytes outfeed_config = 22; + + // The distribution requested for random number generation. + // Only present for kRng. + xla.RandomDistribution distribution = 23; + + // A small float number added to the variance to avoid divide-by-zero error. + // Only present for kBatchNormTraining. + float epsilon = 24; + + // An integer value representing the index of the feature dimension. + // Only present for kBatchNormTraining. + int64 feature_index = 25; + + // Represents a unique identifier for each Send/Recv instruction pair. + // Only present for kSend or kRecv. + int64 channel_id = 26; + + // The string representation of the infeed configuration. + bytes infeed_config = 27; + + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target = 28; + + // Shape of outfeed request. + xla.Shape outfeed_shape = 29; } // Serialization of HloComputation. @@ -67,6 +127,9 @@ message HloComputationProto { // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; + + // The name of the root of the computation. + string root_name = 3; } // Serialization of HloModule. @@ -187,3 +250,7 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } + +message HloProtos { + repeated HloProto hlo_protos = 1; +} diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 444104d88fe34600e7dc38fc806ea34b74660da8..2285518a0e96128436d42bd12ee2fb31142f1ef1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -373,13 +373,14 @@ string HloComputation::ToString(int nested_level) const { for (int i = 0; i < nested_level; i++) { s << " "; } - s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " { \n"; + s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) + << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { for (int i = 0; i < nested_level; i++) { s << " "; } - s << " " << instruction->ToString() << "\n"; + s << " " << (instruction == root_instruction_ ? "ROOT " : "") + << instruction->ToString() << "\n"; if (instruction->opcode() == HloOpcode::kFusion) { s << instruction->fused_instructions_computation()->ToString( nested_level + 1) @@ -400,9 +401,38 @@ HloComputationProto HloComputation::ToProto() const { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } + proto.set_root_name(root_instruction()->name()); return proto; } +/* static */ StatusOr> +HloComputation::CreateFromProto( + HloModule* module, const HloComputationProto& proto, + tensorflow::gtl::FlatMap* computation_map, + HloInstruction* fusion_instruction) { + std::vector> instructions; + tensorflow::gtl::FlatMap instruction_map; + int64 parameter_count = 0; + for (const HloInstructionProto& instruction_proto : proto.instructions()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_count++; + } + TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); + instruction_map[instruction->name()] = instruction.get(); + instructions.push_back(std::move(instruction)); + } + + TF_RET_CHECK(!proto.root_name().empty()); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); + HloInstruction* root = instruction_map.at(proto.root_name()); + return WrapUnique(new HloComputation( + proto.name(), parameter_count, &instructions, root, fusion_instruction)); +} + void HloComputation::FuseInstructionsInto( tensorflow::gtl::ArraySlice instructions_to_fuse, HloInstruction* fusion_instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index b929b41bad403e76948e00fa627829565ed324fd..f4edd175016ee30d31cc0cad6bdbd3eaa014c704 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -143,6 +143,22 @@ class HloComputation { // Returns a serialized representation of this computation. HloComputationProto ToProto() const; + // Creates a computation from the given proto. Arguments: + // + // module: the module which will contain the computation. The newly created + // computation is *not* added to the module, however. + // proto: the proto to convert from. + // computation_map: a map from computation name to HloComputation*. This map + // must contain all computations which the newly constructed computation + // calls. + // fusion_instruction: if non-null then the newly created computation will be + // constructed as a fused computation with this instruction as its fusion + // parent. + static StatusOr> CreateFromProto( + HloModule* module, const HloComputationProto& proto, + tensorflow::gtl::FlatMap* computation_map, + HloInstruction* fusion_instruction = nullptr); + // Gets the instructions in this computation. // // The returned type is a range of HloInstruction*s, so you can iterate over @@ -296,8 +312,7 @@ class HloComputation { explicit HloComputation( const string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction, - HloInstruction* fusion_instruction = nullptr); + HloInstruction* root_instruction, HloInstruction* fusion_instruction); // Internal helper for adding instructions. HloInstruction* AddInstructionInternal( @@ -343,11 +358,6 @@ class HloComputation { std::vector param_instructions_; - // Unique name generator for instruction identifiers. Instruction names should - // be unique per computation and this is enforced when instructions are added - // to the computation. - NameUniquer instruction_name_uniquer_; - TF_DISALLOW_COPY_AND_ASSIGN(HloComputation); }; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index ccab7bf34862f3303db1331a87b5c70fdc3283ba..7b7588f4ba9aa622677db6f9d5022cc8cc029e04 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -310,7 +310,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { } TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { - // Test that DeepCopyInstruction properly copies elements of a a tuple as + // Test that DeepCopyInstruction properly copies elements of a tuple as // specified by the given indices. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index b30c7b417f3785bd485f17d7f46a8b47ef4d4b58..53450991b6fad5b9651d9d23b55c908e6b68e5dd 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -49,8 +49,8 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } // Skip Constant, Parameter, Reduce operation. - // TODO(b/35975797): Enable Reduce operation once arbitary computation are - // supported by the evaluator. + // TODO(b/35975797): Enable Reduce operation once arbitrary computation + // are supported by the evaluator. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || @@ -63,8 +63,8 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } - // Broadcasts dramatically increase the size of constants with is often - // detrimental to performance and memory capacity so do not fold + // Broadcasts dramatically increase the size of constants, which is often + // detrimental to performance and memory capacity, so do not fold // broadcasts. if (instruction->opcode() == HloOpcode::kBroadcast) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 65725ca692fb3429106f5ed50f4a2c11bd46f54c..ca99fd6de8413b0511827f8a178c90391968de8e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -393,12 +393,14 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, const Window& window) { const auto& dnums = convolution->convolution_dimension_numbers(); const int64 output_features = - convolution->shape().dimensions(dnums.feature_dimension()); + convolution->shape().dimensions(dnums.output_feature_dimension()); // For each output element, we do one fma per element in the kernel at some // given output feature index. const int64 fmas_per_output_element = - ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features; + output_features > 0 + ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features + : 0; const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); current_properties_[kFlopsKey] = output_elements * fmas_per_output_element * kFmaFlops; diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 71321e5e9ae733ea06d7988e2a301bf838022563..a4921232f5848dbe1789c4c641e2b0ba3c1848bb 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -64,6 +64,29 @@ StatusOr HloDCE::Run(HloModule* module) { } } + // Now DCE HloComputations. First, collect the computations that are + // referenced by some remaining instruction. + std::unordered_set live_computations; + if (HloComputation* entry_computation = module->entry_computation()) { + live_computations.insert(entry_computation); + } + for (auto* computation : module->MakeComputationPostOrder()) { + for (auto* instruction : computation->instructions()) { + for (auto* subcomp : instruction->called_computations()) { + live_computations.insert(subcomp); + } + } + } + + // Remove dead computations. + std::list computations = module->MakeComputationPostOrder(); + for (auto* computation : computations) { + if (live_computations.count(computation) == 0) { + TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); + changed = true; + } + } + return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index fca3fa0f58b7c5929c6ffa6c2d8ae6f76660b380..4e244494d6f98c48f4376bd762f116b9a9c2084d 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -24,10 +24,15 @@ limitations under the License. namespace xla { -// HLO pass which removes all dead instructions from each computation in the -// module. An instruction is dead if it is not reachable from the root. This -// pass does not remove dead parameter instructions as parameter instructions -// cannot be deleted, nor does the pass remove dead computations. +// HLO pass which removes dead instructions from each computation in the module +// and removes dead computations from the module. +// +// An instruction is dead if it is not reachable from the root. A computation is +// dead if it is not the entry computation of the module and it is not reachable +// from the entry computation. +// +// This pass does not remove dead parameter instructions, as parameter +// instructions cannot be deleted. class HloDCE : public HloPassInterface { public: ~HloDCE() override {} diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index fa0ab98649abf0c452ca695af34d1c6f6ce116f5..d54b9a27087a42fd23eab0bd06e8deaca567312b 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -299,5 +299,93 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { EXPECT_TRUE(HasInstruction(*computation, live_call)); } +TEST_F(HloDceTest, RemoveDeadSubcomputation) { + auto module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation::Builder subcomp_builder("reduction_subcomp"); + { + auto* param0 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0")); + auto* param1 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1")); + subcomp_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1)); + } + auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build()); + + // Create a dead reduce instruction. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + // Add another instruction as the root of the computation. + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + // We should have DCE'ed the reduction computation along with the reduction + // instruction. + EXPECT_EQ(module->MakeComputationPostOrder().size(), 1); +} + +TEST_F(HloDceTest, KeepUsedSubcomputation) { + auto module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation::Builder subcomp_builder("reduction_subcomp"); + { + auto* param0 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0")); + auto* param1 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1")); + subcomp_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1)); + } + auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build()); + + // Create a dead reduce instruction. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + // Add another instruction as the root of the computation that also uses + // reduce_subcomp. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + // We shouldn't have DCE'ed reduce_subcomp, even though we removed one of + // its users. + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 4f9d6c00961d3027d07f2581ca410f88b6b2dad8..f4a2c3d0e88e32c8352a4da73ff8c06e33482985 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -50,6 +50,12 @@ namespace xla { namespace { +template +struct is_complex_t : public std::false_type {}; + +template <> +struct is_complex_t : public std::true_type {}; + template StatusOr> Compare(const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, @@ -101,6 +107,37 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, return std::move(result); } +template <> +StatusOr> Compare( + const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, + const Literal& rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](complex64 lhs_el, complex64 rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](complex64 lhs_el, complex64 rhs_el) { + return lhs_el != rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + auto result = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + + return std::move(result); +} + template StatusOr> ElementWiseUnaryOpImpl( HloInstruction* instruction, @@ -138,7 +175,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", HloOpcodeString(hlo_instruction->opcode()).c_str()); - }; + } // TODO(b/35950897): many of the stl functions used in the handlers are not // overloaded for every XLA primitive types. @@ -156,7 +193,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value>::type* = nullptr> + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], ElementWiseUnaryOp(abs, [](NativeT elem_operand) { @@ -169,7 +207,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleAbs(abs, operand); } - Status HandleRound(HloInstruction* round) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[round], ElementWiseUnaryOp(round, [](ReturnT elem_operand) { return std::round(elem_operand); @@ -177,6 +218,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + return InvalidArgument("Unsupported type for Round"); + } + + Status HandleRound(HloInstruction* round) override { + return HandleRound(round); + } + Status HandleBroadcast(HloInstruction* broadcast) override { parent_->evaluated_[broadcast] = Literal::CreateFromShape(broadcast->shape()); @@ -205,15 +257,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } return operand_to_broadcast.Get(broadcast_indices); }); - }; + } - Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { return std::ceil(elem_operand); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + return InvalidArgument("Unsupported type for Ceil"); + } + + Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { + return HandleCeil(ceil); + } Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); @@ -237,15 +303,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::exp(elem_operand); })); return Status::OK(); - }; + } - Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor], ElementWiseUnaryOp(floor, [](ReturnT elem_operand) { return std::floor(elem_operand); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Floor"); + } + + Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override { + return HandleFloor(floor); + } Status HandleLog(HloInstruction* log, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], @@ -253,16 +333,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::log(elem_operand); })); return Status::OK(); - }; + } - Status HandleLogicalNot(HloInstruction* logical_not, - HloInstruction* operand) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_not], - ElementWiseUnaryOp(logical_not, - [](ReturnT elem_operand) { return !elem_operand; })); + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + return !elem_operand; + })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + return InvalidArgument("Unsupported type for Not"); + } + + Status HandleNot(HloInstruction* not_, HloInstruction* operand) override { + return HandleNot(not_); + } Status HandleNegate(HloInstruction* negate, HloInstruction* operand) override { @@ -271,16 +364,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return -elem_operand; })); return Status::OK(); - }; + } - Status HandleSign(HloInstruction* sign, HloInstruction* operand) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { return (ReturnT(0) < elem_operand) - (elem_operand < ReturnT(0)); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { + auto abs_val = std::abs(elem_operand); + return 0 == abs_val ? ReturnT(0) + : elem_operand / abs_val; + })); + return Status::OK(); + } + + Status HandleSign(HloInstruction* sign, HloInstruction* operand) override { + return HandleSign(sign); + } Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], @@ -288,7 +401,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::tanh(elem_operand); })); return Status::OK(); - }; + } Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) override { @@ -298,7 +411,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return lhs_elem * rhs_elem; })); return Status::OK(); - }; + } Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, HloInstruction* rhs) override { @@ -308,7 +421,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return lhs_elem - rhs_elem; })); return Status::OK(); - }; + } Status HandleAdd(HloInstruction* add, HloInstruction* lhs, HloInstruction* rhs) override { @@ -318,7 +431,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return lhs_elem + rhs_elem; })); return Status::OK(); - }; + } Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) override { @@ -328,25 +441,53 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return lhs_elem / rhs_elem; })); return Status::OK(); - }; + } - Status HandleMaximum(HloInstruction* maximum) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { return std::fmax(lhs, rhs); })); return Status::OK(); - }; + } - Status HandleMinimum(HloInstruction* minimum) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + return InvalidArgument("Unsupported type for Maximum"); + } + + Status HandleMaximum(HloInstruction* maximum) override { + return HandleMaximum(maximum); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { return std::fmin(lhs_el, rhs_el); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + return InvalidArgument("Unsupported type for Minimum"); + } + + Status HandleMinimum(HloInstruction* minimum) override { + return HandleMinimum(minimum); + } Status HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) override { @@ -356,40 +497,172 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::pow(lhs_el, rhs_el); })); return Status::OK(); - }; + } - Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, - HloInstruction* rhs) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { return std::fmod(lhs_el, rhs_el); })); return Status::OK(); - }; + } - Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, - HloInstruction* rhs) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + return InvalidArgument("Unsupported type for Remainder"); + } + + Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleRemainder(remainder); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_and], - ElementWiseBinaryOp(logical_and, [](ReturnT lhs_el, ReturnT rhs_el) { + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { return lhs_el && rhs_el; })); return Status::OK(); - }; + } - Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, - HloInstruction* rhs) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + return InvalidArgument("Unsupported type for And"); + } + + Status HandleAnd(HloInstruction* and_, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleAnd(and_); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_or], - ElementWiseBinaryOp(logical_or, [](ReturnT lhs_el, ReturnT rhs_el) { + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { return lhs_el || rhs_el; })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + return InvalidArgument("Unsupported type for Or"); + } + + Status HandleOr(HloInstruction* or_, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleOr(or_); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return lhs_elem << rhs_elem; + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftLeft(shl, lhs, rhs); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_signed::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightArithmetic(shra, lhs, rhs); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_unsigned::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightLogical(shrl, lhs, rhs); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> Status HandleClamp(HloInstruction* clamp, HloInstruction* min, - HloInstruction* arg, HloInstruction* max) override { + HloInstruction* arg, HloInstruction* max) { std::function clamp_op = [](ReturnT low, ReturnT high, ReturnT value) { return std::fmax(low, std::fmin(value, high)); @@ -397,7 +670,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], ElementWiseTernaryOp(clamp, std::move(clamp_op))); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, + HloInstruction* arg, HloInstruction* max) { + return InvalidArgument("Unsupported type for Clamp"); + } + + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, + HloInstruction* arg, HloInstruction* max) override { + return HandleClamp(clamp, min, arg, max); + } Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, @@ -413,7 +699,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], ElementWiseTernaryOp(select, std::move(select_op))); return Status::OK(); - }; + } Status HandleReverse(HloInstruction* reverse, HloInstruction* operand) override { @@ -443,7 +729,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[reverse] = std::move(result); return Status::OK(); - }; + } Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override { @@ -461,7 +747,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const auto& dnums = conv->convolution_dimension_numbers(); const int64 num_spatial_dims = dnums.spatial_dimensions_size(); CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); - CHECK_GE(num_spatial_dims, 1); + CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); const auto lhs_rank = ShapeUtil::Rank(lhs_shape); @@ -481,14 +767,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - // Dimension number applicable for both input (lhs), and output. - const int64 batch_dim = dnums.batch_dimension(); - const int64 z_dim = dnums.feature_dimension(); + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); // Dimension number applicable for kernel (rhs). const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, z_dim); + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); std::vector window_dimension_sizes; for (auto i : dnums.kernel_spatial_dimensions()) { @@ -509,13 +798,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::fill(rhs_index.begin(), rhs_index.end(), 0); std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0); - lhs_index[batch_dim] = out_index[batch_dim]; - rhs_index[kernel_output_z_dim] = out_index[z_dim]; + lhs_index[input_batch_dim] = out_index[output_batch_dim]; + rhs_index[kernel_output_z_dim] = out_index[output_z_dim]; // Convolve input feature with kernel. do { for (int64 iz = 0; iz < z_size; ++iz) { - lhs_index[z_dim] = iz; + lhs_index[input_z_dim] = iz; rhs_index[kernel_input_z_dim] = iz; // Find corresponding spatial dimension index for input (lhs). @@ -563,7 +852,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[conv] = std::move(result); return Status::OK(); - }; + } Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) override { @@ -630,7 +919,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[dot] = std::move(result); return Status::OK(); - }; + } Status HandlePad(HloInstruction* pad) override { CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); @@ -699,7 +988,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[pad] = std::move(result); return Status::OK(); - }; + } Status HandleDynamicSlice(HloInstruction* dynamic_slice, HloInstruction* operand, @@ -752,7 +1041,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } return Status::OK(); - }; + } Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* operand, @@ -808,7 +1097,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } return Status::OK(); - }; + } Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, @@ -896,7 +1185,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[reduce] = std::move(result); return Status::OK(); - }; + } Status HandleReduceWindow(HloInstruction* reduce_window, HloInstruction* operand, const Window& window, @@ -983,7 +1272,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[reduce_window] = std::move(result); return Status::OK(); - }; + } Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override { const Shape& shape = slice->shape(); @@ -1012,7 +1301,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_RETURN_IF_ERROR(result->Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); - }; + } private: template @@ -1155,32 +1444,33 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator* parent_; -}; // namespace xla +}; // class HloEvaluator::TypedVisitor HloEvaluator::HloEvaluator() { typed_visitors_[PRED] = MakeUnique>(this); typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: U16."); + return Unimplemented("HloEvaluator: unhandled primitive type: U16."); }); typed_visitors_[U32] = MakeUnique>(this); typed_visitors_[U64] = MakeUnique>(this); typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: S16."); + return Unimplemented("HloEvaluator: unhandled primitive type: S16."); }); typed_visitors_[S32] = MakeUnique>(this); typed_visitors_[S64] = MakeUnique>(this); typed_visitors_[F16] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: F16."); + return Unimplemented("HloEvaluator: unhandled primitive type: F16."); }); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[C64] = MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: TUPLE."); + return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: OPAQUE."); + return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE."); }); } @@ -1241,8 +1531,14 @@ StatusOr> HloEvaluator::Evaluate( StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction) { - TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction)); - TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter); + if (instruction->opcode() == HloOpcode::kParameter) { + return tensorflow::errors::FailedPrecondition( + "Cannot evaluate a parameter."); + } + if (!hlo_query::AllOperandsAreConstants(*instruction)) { + return tensorflow::errors::FailedPrecondition( + "Not all operands are constants."); + } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_.clear(); @@ -1285,8 +1581,17 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( operands.push_back(operand.get()); } - return Evaluate( - instruction->CloneWithNewOperands(instruction->shape(), operands).get()); + std::unique_ptr cloned_instruction = + instruction->CloneWithNewOperands(instruction->shape(), operands); + auto result = Evaluate(cloned_instruction.get()); + + // Clean up our cloned instructions before returning. + cloned_instruction->DetachFromOperands(); + for (auto& operand : owned_operands) { + operand->DetachFromOperands(); + } + + return result; } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { @@ -1466,6 +1771,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; + case C64: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; default: LOG(FATAL) << "HandleCompare: unknown primitive type: " << PrimitiveType_Name(lhs->shape().element_type()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index a8a73e866ee08600dcdf58d7618b30514a2b4ca1..85477af6fe26f53504c07204348566c16a24392c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -41,7 +41,7 @@ limitations under the License. namespace xla { namespace { -class HloEvaluatorTest : public HloTestBase { +class HloEvaluatorTest : public HloVerifiedTestBase { protected: HloEvaluatorTest() { evaluator_ = MakeUnique(); } @@ -62,8 +62,7 @@ TEST_F(HloEvaluatorTest, DoesClamp) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto instruction = b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); @@ -89,8 +88,7 @@ TEST_F(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); auto instruction = b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); @@ -112,8 +110,7 @@ TEST_F(HloEvaluatorTest, DoesAdd) { auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); auto instruction = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); @@ -125,111 +122,100 @@ TEST_F(HloEvaluatorTest, DoesAdd) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_F(HloEvaluatorTest, DoesDivide) { - { - auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64))); - auto c2_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_s64, HloOpcode::kDivide, c1_s64, c2_s64)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - - auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } - { - auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); - - Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64))); - auto c2_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_f64, HloOpcode::kDivide, c1_f64, c2_f64)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - - auto expected = - Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } +TEST_F(HloEvaluatorTest, DoesDivideInt64) { + auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); + + Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); + HloComputation::Builder b(TestName()); + auto c1_s64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64))); + auto c2_s64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64))); + auto instruction = b.AddInstruction(HloInstruction::CreateBinary( + shape_s64, HloOpcode::kDivide, c1_s64, c2_s64)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} +TEST_F(HloEvaluatorTest, DoesDivideDouble) { + auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + + Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); + HloComputation::Builder b(TestName()); + auto c1_f64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64))); + auto c2_f64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64))); + auto instruction = b.AddInstruction(HloInstruction::CreateBinary( + shape_f64, HloOpcode::kDivide, c1_f64, c2_f64)); + module().AddEntryComputation(b.Build()); + + auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = + Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. -TEST_F(HloEvaluatorTest, DoesAbs) { - { - auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = b.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } +TEST_F(HloEvaluatorTest, DoesAbsR2) { + auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); + const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); + auto instruction = + b.AddInstruction(HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1)); + module().AddEntryComputation(b.Build()); + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} +TEST_F(HloEvaluatorTest, DoesAbsR0) { // For R0 literal. - { - const Shape& r0 = ShapeUtil::MakeShape(F32, {}); - auto operand = Literal::CreateR0(-1.0f); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); - auto expected = Literal::CreateR0(1.0f); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } + const Shape& r0 = ShapeUtil::MakeShape(F32, {}); + auto operand = Literal::CreateR0(-1.0f); + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); + auto instruction = + b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1)); + module().AddEntryComputation(b.Build()); + + auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); + auto expected = Literal::CreateR0(1.0f); + LiteralTestUtil::ExpectEqual(*expected, *result); +} +TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) { // For R1 literal with dimension of size 0. - { - Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); - auto operand = Literal::CreateR1({}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = b.AddInstruction( - HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); - auto expected = Literal::CreateR1({}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } -} // namespace + Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); + auto operand = Literal::CreateR1({}); + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); + auto instruction = b.AddInstruction( + HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1)); + module().AddEntryComputation(b.Build()); + + auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); + auto expected = Literal::CreateR1({}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. @@ -253,8 +239,7 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, args).ConsumeValueOrDie(); @@ -279,8 +264,7 @@ TEST_F(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -303,8 +287,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, {1, 2})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -324,8 +307,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, /*broadcast_dimensions=*/{})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -343,11 +325,10 @@ TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { std::vector operands = {operand1, operand2}; - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -370,8 +351,7 @@ TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -392,8 +372,7 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -414,8 +393,7 @@ TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -451,8 +429,7 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); auto pad_instruction = b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); auto result = evaluator_->Evaluate(pad_instruction).ConsumeValueOrDie(); @@ -479,8 +456,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -525,8 +501,7 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -572,8 +547,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -609,8 +583,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -653,8 +626,7 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -695,8 +667,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -736,8 +707,10 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.set_kernel_output_feature_dimension(0); @@ -747,8 +720,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -803,8 +775,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -868,8 +839,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(2); - dnums.set_feature_dimension(0); + dnums.set_input_batch_dimension(2); + dnums.set_output_batch_dimension(2); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(3); @@ -881,8 +854,7 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -940,8 +912,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1005,8 +976,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1078,8 +1048,7 @@ TEST_F(HloEvaluatorTest, const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1127,15 +1096,14 @@ TEST_F(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloModule module(TestName()); - auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction( HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1171,8 +1139,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowMax) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - HloModule module(TestName()); - auto max_func = module.AddEmbeddedComputation(max_computation.Build()); + auto max_func = module().AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1189,7 +1156,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1223,8 +1190,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloModule module(TestName()); - auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); Window window; WindowDimension dim; @@ -1247,7 +1213,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1277,8 +1243,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloModule module(TestName()); - auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); Window window; @@ -1309,7 +1274,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1340,8 +1305,7 @@ TEST_F(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1375,8 +1339,7 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1412,8 +1375,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1450,8 +1412,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1487,8 +1448,7 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1527,8 +1487,7 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1568,8 +1527,7 @@ TEST_F(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 9b4a2f1048cb0644e6ba81e4e13115b608e4fcc0..e000a0670614d09976c74efd49948267191bbf79 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -231,9 +231,9 @@ string HtmlLikeStringSanitize(tensorflow::StringPiece s) { // commutative, we also support them with param0 and param1 swapped. // // This is useful primarily for reduce and map nodes. These take a -// subcomputation which is almost always one of the four above, and pattern -// matching it to a short string lets us tell the user what the subcomputation -// is without drawing it as a graph. +// subcomputation which is almost always one of the above, and pattern matching +// it to a short string lets us tell the user what the subcomputation is without +// drawing it as a graph. optional MatchTrivialComputation(const HloComputation* computation) { if (computation->instruction_count() != 3) { return nullopt; @@ -342,6 +342,11 @@ class HloDotDumper { bool ShouldShowSubcomputation(const HloComputation* subcomp); bool ShouldShowFusionSubcomputation(const HloInstruction* instr); + + // We omit some nodes from the graph, instead drawing them inlined into the + // nodes that use them. + bool ShouldMergeIntoUsers(const HloInstruction* instr) const; + string DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr); string DumpComputation(const HloComputation* comp); @@ -352,9 +357,24 @@ class HloDotDumper { string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); - string GetInstructionNodeInlinedConstants(const HloInstruction* instr); + string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); + // For most instructions, GetNodeForEdge(instr) returns instr. + // + // The exception is fusion nodes. For these, we walk up the chain of nested + // fusion nodes starting at instr until we reach a node that either (a) isn't + // a fusion node, or (b) is a fusion node for which + // ShouldShowFusionSubcomputation is false. + // + // We do this because fusion nodes are expanded inline -- if + // ShouldShowFusionSubcomputation is true, the fusion node won't be present in + // the graph. + // + // In general when you want to draw an edge from A to B, you should actually + // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B). + const HloInstruction* GetNodeForEdge(const HloInstruction* instr); + // If instr has just one computation and it's trivial (e.g. "return param0 + // param1"), returns a string you can put into the node's body that names the // subcomputation, e.g. "Subcomputation: add". @@ -590,16 +610,15 @@ tooltip = " "; // belongs to a fusion node, it's drawn in place of the fusion instruction, // so there's no need to link those. if (parent_instr->opcode() != HloOpcode::kFusion) { - VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to " - << parent_instr->name() << " as " << next_edge_id_; - edge_ids_.insert( - {{subcomp->root_instruction(), parent_instr}, next_edge_id_++}); + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); const char* edge_fmt = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back( - Printf(edge_fmt, InstructionId(subcomp->root_instruction()), - InstructionId(parent_instr), SubcomputationId(subcomp), - subcomp->name(), parent_instr->name())); + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } string computation = @@ -628,15 +647,7 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) { } string HloDotDumper::DumpRootTag() { - HloInstruction* from = computation_->root_instruction(); - - // Fusion nodes are expanded inline, so if root is an expanded fusion node, - // walk up the graph until we find a node that isn't. - while (from->opcode() == HloOpcode::kFusion && - ShouldShowFusionSubcomputation(from)) { - from = from->fused_expression_root(); - } - + const HloInstruction* from = GetNodeForEdge(computation_->root_instruction()); auto from_id = InstructionId(from); if (!filter_.Show(from)) { @@ -668,12 +679,42 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { + // If a node: + // + // - is a tuple-shaped parameter, + // - is not a parameter to a fusion node, + // - has at least kMinUsersToOmit users shown, and + // - all of the shown users are get-tuple-elements, + // + // then we omit it from the graph, merging it with its users. + // + // This helps us handle the common case where a while loop body has one big + // tuple-shaped parameter. + const int kMinUsersToOmit = 3; + return instr->opcode() == HloOpcode::kParameter && + ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && + std::count_if(instr->users().begin(), instr->users().end(), + [&](const HloInstruction* user) { + return filter_.Show(user); + }) > kMinUsersToOmit && + std::all_of(instr->users().begin(), instr->users().end(), + [&](const HloInstruction* user) { + return !filter_.Show(user) || + user->opcode() == HloOpcode::kGetTupleElement; + }); +} + string HloDotDumper::DumpInstruction(const HloInstruction* instr) { // We don't display constants as separate nodes; they're merged into their // users. if (instr->opcode() == HloOpcode::kConstant) { return ""; } + // Skip this node if it's merged into its users. + if (ShouldMergeIntoUsers(instr)) { + return ""; + } // Omit the fusion node if its subcomputation is drawn, since the // subcomputation will be drawn inline. if (instr->opcode() == HloOpcode::kFusion && @@ -689,7 +730,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); string extra_info = GetInstructionNodeExtraInfo(instr); - string inlined_constants = GetInstructionNodeInlinedConstants(instr); + string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); AddInstructionIncomingEdges(instr); @@ -717,7 +758,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { NodeColorAttributes(color)); } -string HloDotDumper::GetInstructionNodeInlinedConstants( +string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { auto stringify_constant = [](const HloInstruction* constant) { if (ShapeUtil::IsEffectiveScalar(constant->shape())) { @@ -726,10 +767,14 @@ string HloDotDumper::GetInstructionNodeInlinedConstants( return Printf("%s (%s)", constant->literal().GetAsString(elem_idx), ShapeUtil::HumanString(constant->shape())); } + string constant_name; if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { - return constant->name(); + constant_name = constant->name(); + } else { + constant_name = StrCat("constant ", constant->name()); } - return StrCat("constant ", constant->name()); + return Printf("%s %s", constant_name, + ShapeUtil::HumanString(constant->shape())); }; // Special case: If instr is a parameter to a fusion node, check whether the @@ -746,16 +791,44 @@ string HloDotDumper::GetInstructionNodeInlinedConstants( std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); - if (operand->opcode() != HloOpcode::kConstant) { - continue; + optional operand_str; + if (operand->opcode() == HloOpcode::kConstant) { + operand_str = stringify_constant(operand); + } else if (ShouldMergeIntoUsers(operand)) { + // Special case: If the operand is a parameter, use its parameter number + // rather than its name, because that's generally how people think of the + // node. + if (operand->opcode() == HloOpcode::kParameter) { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } else { + operand_str = operand->name(); + } + } + + if (operand_str) { + if (instr->operand_count() > 1) { + lines.push_back(Printf("operand %lld = %s", i, *operand_str)); + } else { + lines.push_back(Printf("operand = %s", *operand_str)); + } } - lines.push_back( - Printf("operand %lld = %s", i, stringify_constant(operand))); } return Join(lines, "
"); } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { + const auto kParameterColor = kOrange; + + // Special case: If this instruction has a parameter merged into it, paint it + // the same color as a parameter. + if (std::any_of(instr->operands().begin(), instr->operands().end(), + [&](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kParameter && + ShouldMergeIntoUsers(operand); + })) { + return kParameterColor; + } + // Pick different colors or shapes for instructions which are particularly // expensive (eg, dot) and those which are unusual in some way or unique // (eg, parameter). @@ -763,8 +836,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kAtan2: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: @@ -773,13 +848,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: + case HloOpcode::kImag: case HloOpcode::kIndex: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -787,8 +863,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kPower: + case HloOpcode::kReal: case HloOpcode::kRemainder: - case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: @@ -796,22 +875,46 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kRng: - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: + // De-emphasize scalar-shaped elementwise ops -- they're generally + // uninteresting. + if (ShapeUtil::IsEffectiveScalar(instr->shape())) { + return kWhite; + } return kYellow; case HloOpcode::kBitcast: case HloOpcode::kTuple: case HloOpcode::kTrace: case HloOpcode::kGetTupleElement: return kWhite; + case HloOpcode::kBroadcast: + // De-emphasize nodes which broadcast a scalar within a fusion node -- + // these are essentially free. + if (instr->IsFused() && + ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) { + return kWhite; + } + return kGreen; case HloOpcode::kConcatenate: case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kUpdate: + case HloOpcode::kSelect: + case HloOpcode::kTranspose: + // De-emphasize scalar-shaped data movement ops and all data movement ops + // inside fusion nodes, both of which are essentially free. + if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) { + return kWhite; + } + return kGreen; + case HloOpcode::kDynamicUpdateSlice: + // Unlike the data-movement ops above, dynamic-update-slice is not ~free + // inside of fusion nodes, so we de-emphasize it only if it's + // scalar-shaped. + if (ShapeUtil::IsEffectiveScalar(instr->shape())) { + return kWhite; + } return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: @@ -819,7 +922,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReducePrecision: return kRed; case HloOpcode::kParameter: - return kOrange; + return kParameterColor; case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: @@ -924,6 +1027,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { [](int64 stride) { return stride == 1; }) ? "" : StrCat("stride=", VectorString(instr->slice_strides())); + case HloOpcode::kSend: + case HloOpcode::kRecv: + return StrCat("channel_id=", instr->channel_id()); default: return ""; } @@ -933,7 +1039,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { if (!opcode_specific_info.empty()) { lines.push_back(opcode_specific_info); } - + if (instr->device_assignment().has_device()) { + lines.push_back(StrCat("device=", instr->device_assignment().device())); + } // Show the shape and layout of the instruction, unless it's an inlined fusion // node -- there the shape and layout is present in the output node. if (instr->opcode() != HloOpcode::kFusion || @@ -978,14 +1086,10 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { - // Fusion nodes' subcomputations are displayed inline, so if 'from' is a - // fusion node and the node's subcomputation is shown, we draw our edge - // starting at the fusion node's root instead of at the fusion node itself. - if (from->opcode() == HloOpcode::kFusion && - ShouldShowFusionSubcomputation(from)) { - from = from->fused_expression_root(); - } - if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { + from = GetNodeForEdge(from); + + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant || + ShouldMergeIntoUsers(from)) { return; } VLOG(2) << "Adding edge from " << from->name() << " to " << to->name() @@ -1051,6 +1155,15 @@ string HloDotDumper::GetInstructionTrivialComputationStr( return Join(lines, "
"); } +const HloInstruction* HloDotDumper::GetNodeForEdge( + const HloInstruction* instr) { + while (instr->opcode() == HloOpcode::kFusion && + ShouldShowFusionSubcomputation(instr)) { + instr = instr->fused_expression_root(); + } + return instr; +} + tensorflow::mutex& RendererMutex() { static tensorflow::mutex* mu = new tensorflow::mutex; return *mu; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 7419ab8704a7090170ecf6fbed70236656e8b602..1a03e7ee92cc788b06d9c05750123da26770ff12 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -47,6 +47,101 @@ using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; +/* static */ +StatusOr> HloInstruction::CreateFromProto( + HloModule* module, const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + tensorflow::gtl::FlatMap* computation_map) { + TF_RET_CHECK(!proto.opcode().empty()); + TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); + TF_RET_CHECK(proto.has_shape()); + + auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const string& operand_name : proto.operand_names()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) + << "No instruction named " << operand_name; + instruction->AppendOperand(instruction_map.at(operand_name)); + } + for (const string& predecessor_name : proto.control_predecessor_names()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) + << "No instruction named " << predecessor_name; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) + ->AddControlDependencyTo(instruction.get())); + } + + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within the + // HloModuleProto. + if (instruction->opcode() == HloOpcode::kFusion) { + TF_RET_CHECK(proto.has_fused_instructions_computation()); + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, + StringToFusionKind(proto.fusion_kind())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr fused_computation, + HloComputation::CreateFromProto( + module, proto.fused_instructions_computation(), computation_map, + /*fusion_instruction=*/instruction.get())); + instruction->called_computations_.push_back( + module->AddEmbeddedComputation(std::move(fused_computation))); + } else { + for (const string& computation_name : proto.called_computation_names()) { + TF_RET_CHECK(ContainsKey(*computation_map, computation_name)) + << "No computation named " << computation_name; + instruction->called_computations_.push_back( + computation_map->at(computation_name)); + } + } + + TF_RET_CHECK(!proto.name().empty()); + instruction->name_ = proto.name(); + + instruction->metadata_ = proto.metadata(); + if (proto.has_literal()) { + instruction->literal_ = MakeUnique(proto.literal()); + } + instruction->parameter_number_ = proto.parameter_number(); + instruction->parameter_name_ = proto.parameter_name(); + + instruction->tuple_index_ = proto.tuple_index(); + for (int64 dimension : proto.dimensions()) { + instruction->dimensions_.push_back(dimension); + } + if (proto.has_window()) { + instruction->window_ = MakeUnique(proto.window()); + } + if (proto.has_convolution_dimension_numbers()) { + instruction->convolution_dimension_numbers_ = + MakeUnique( + proto.convolution_dimension_numbers()); + } + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + instruction->slice_starts_.push_back(slice_dimensions.start()); + instruction->slice_limits_.push_back(slice_dimensions.limit()); + instruction->slice_strides_.push_back(slice_dimensions.stride()); + } + instruction->exponent_bits_ = proto.exponent_bits(); + instruction->mantissa_bits_ = proto.mantissa_bits(); + for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { + instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); + } + if (proto.has_padding_config()) { + instruction->padding_config_ = + MakeUnique(proto.padding_config()); + } + instruction->outfeed_config_ = proto.outfeed_config(); + instruction->distribution_ = proto.distribution(); + instruction->epsilon_ = proto.epsilon(); + instruction->feature_index_ = proto.feature_index(); + instruction->channel_id_ = proto.channel_id(); + instruction->infeed_config_ = proto.infeed_config(); + instruction->custom_call_target_ = proto.custom_call_target(); + instruction->outfeed_shape_ = proto.outfeed_shape(); + + return std::move(instruction); +} + /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { auto instruction = @@ -124,10 +219,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSort: @@ -146,23 +243,28 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // Only certain opcodes are supported with CreateBinary: opcodes of binary // instructions with no auxiliary fields. switch (opcode) { - case (HloOpcode::kAdd): - case (HloOpcode::kDivide): - case (HloOpcode::kDot): - case (HloOpcode::kEq): - case (HloOpcode::kGe): - case (HloOpcode::kGt): - case (HloOpcode::kLe): - case (HloOpcode::kLt): - case (HloOpcode::kMaximum): - case (HloOpcode::kMinimum): - case (HloOpcode::kMultiply): - case (HloOpcode::kNe): - case (HloOpcode::kPower): - case (HloOpcode::kRemainder): - case (HloOpcode::kSubtract): - case (HloOpcode::kLogicalAnd): - case (HloOpcode::kLogicalOr): + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kDivide: + case HloOpcode::kComplex: + case HloOpcode::kDot: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: break; default: LOG(FATAL) << "Invalid binary instruction opcode " @@ -618,10 +720,12 @@ void HloInstruction::MergeFusionInstructionIntoMultiOutput( // Fuse the root instruction and generate multiple outputs. FuseInstructionIntoMultiOutput(unfused_root); + TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); // The rest instructions are of normal fusing. for (int64 i = 1; i < unfused_instructions.size(); i++) { auto instruction = unfused_instructions[i]; FuseInstruction(instruction); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); } } @@ -864,6 +968,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( VLOG(3) << " " << new_operand->name(); } + std::unique_ptr clone; + // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. @@ -876,19 +982,24 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); - return CreateUnary(shape, opcode_, new_operands[0]); + clone = CreateUnary(shape, opcode_, new_operands[0]); + break; // Binary ops. case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kMultiply: case HloOpcode::kSubtract: @@ -903,125 +1014,162 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kMinimum: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: CHECK_EQ(new_operands.size(), 2); - return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); + clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); + break; // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: CHECK_EQ(new_operands.size(), 3); - return CreateTernary(shape, opcode_, new_operands[0], new_operands[1], - new_operands[2]); + clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1], + new_operands[2]); + break; // Other supported ops. case HloOpcode::kBroadcast: CHECK_EQ(new_operands.size(), 1); - return CreateBroadcast(shape, new_operands[0], dimensions_); + clone = CreateBroadcast(shape, new_operands[0], dimensions_); + break; case HloOpcode::kCall: - return CreateCall(shape, new_operands, to_apply()); + clone = CreateCall(shape, new_operands, to_apply()); + break; case HloOpcode::kCustomCall: - return CreateCustomCall(shape, new_operands, custom_call_target_); + clone = CreateCustomCall(shape, new_operands, custom_call_target_); + break; case HloOpcode::kConcatenate: - return CreateConcatenate(shape, new_operands, dimensions(0)); + clone = CreateConcatenate(shape, new_operands, dimensions(0)); + break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); - return CreateConvert(shape, new_operands[0]); + clone = CreateConvert(shape, new_operands[0]); + break; case HloOpcode::kReducePrecision: CHECK_EQ(new_operands.size(), 1); - return CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); + clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, + mantissa_bits_); + break; case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); - return CreateConvolve(shape, new_operands[0], new_operands[1], *window_, - *convolution_dimension_numbers_); + clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, + *convolution_dimension_numbers_); + break; case HloOpcode::kCrossReplicaSum: CHECK_EQ(new_operands.size(), 1); - return CreateCrossReplicaSum(shape, new_operands[0]); + clone = CreateCrossReplicaSum(shape, new_operands[0]); + break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); - return CreateGetTupleElement(shape, new_operands[0], tuple_index()); + clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); + break; case HloOpcode::kMap: - return CreateMap(shape, new_operands, to_apply()); + clone = CreateMap(shape, new_operands, to_apply()); + break; case HloOpcode::kPad: CHECK_EQ(new_operands.size(), 2); - return CreatePad(shape, new_operands[0], new_operands[1], - *padding_config_); + clone = + CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); + break; case HloOpcode::kReduce: CHECK_EQ(new_operands.size(), 2); - return CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); + clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, + to_apply()); + break; case HloOpcode::kReduceWindow: CHECK_EQ(new_operands.size(), 2); - return CreateReduceWindow(shape, new_operands[0], new_operands[1], - *window_, to_apply()); + clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], + *window_, to_apply()); + break; case HloOpcode::kSelectAndScatter: CHECK_EQ(new_operands.size(), 3); - return CreateSelectAndScatter(shape, new_operands[0], select(), *window_, - new_operands[1], new_operands[2], - scatter()); + clone = + CreateSelectAndScatter(shape, new_operands[0], select(), *window_, + new_operands[1], new_operands[2], scatter()); + break; case HloOpcode::kReverse: CHECK_EQ(new_operands.size(), 1); - return CreateReverse(shape, new_operands[0], dimensions_); + clone = CreateReverse(shape, new_operands[0], dimensions_); + break; case HloOpcode::kRng: - return CreateRng(shape, distribution_, new_operands); + clone = CreateRng(shape, distribution_, new_operands); + break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); - return CreateReshape(shape, new_operands[0]); + clone = CreateReshape(shape, new_operands[0]); + break; case HloOpcode::kSlice: CHECK_EQ(new_operands.size(), 1); - return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); + clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, + slice_strides_); + break; case HloOpcode::kDynamicSlice: - return CreateDynamicSlice(shape, new_operands[0], new_operands[1], - dynamic_slice_sizes_); + clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], + dynamic_slice_sizes_); + break; case HloOpcode::kDynamicUpdateSlice: CHECK_EQ(new_operands.size(), 3); - return CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], - new_operands[2]); + clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], + new_operands[2]); + break; case HloOpcode::kTranspose: CHECK_EQ(new_operands.size(), 1); - return CreateTranspose(shape, new_operands[0], dimensions_); + clone = CreateTranspose(shape, new_operands[0], dimensions_); + break; case HloOpcode::kTuple: - return CreateTuple(new_operands); + clone = CreateTuple(new_operands); + *clone->mutable_shape() = shape; + break; case HloOpcode::kWhile: CHECK_EQ(new_operands.size(), 1); - return CreateWhile(shape, while_condition(), while_body(), - new_operands[0]); + clone = + CreateWhile(shape, while_condition(), while_body(), new_operands[0]); + break; case HloOpcode::kConstant: - return CreateConstant(literal_->CloneToUnique()); + clone = CreateConstant(literal_->CloneToUnique()); + break; case HloOpcode::kFusion: - return CloneFusionWithNewOperands(shape, new_operands); + clone = CloneFusionWithNewOperands(shape, new_operands); + break; case HloOpcode::kParameter: - return CreateParameter(parameter_number_, shape, parameter_name_); + clone = CreateParameter(parameter_number_, shape, parameter_name_); + break; case HloOpcode::kBatchNormTraining: CHECK_EQ(new_operands.size(), 3); - return CreateBatchNormTraining(shape, new_operands[0], new_operands[1], - new_operands[2], epsilon(), - feature_index()); - + clone = + CreateBatchNormTraining(shape, new_operands[0], new_operands[1], + new_operands[2], epsilon(), feature_index()); + break; case HloOpcode::kBatchNormInference: CHECK_EQ(new_operands.size(), 5); - return CreateBatchNormInference( + clone = CreateBatchNormInference( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); + break; case HloOpcode::kInfeed: CHECK_EQ(new_operands.size(), 0); - return CreateInfeed(shape, infeed_config()); + clone = CreateInfeed(shape, infeed_config()); + break; case HloOpcode::kOutfeed: CHECK_EQ(new_operands.size(), 1); - return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); + clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); + break; case HloOpcode::kBatchNormGrad: CHECK_EQ(new_operands.size(), 5); - return CreateBatchNormGrad(shape, new_operands[0], new_operands[1], - new_operands[2], new_operands[3], - new_operands[4], epsilon(), feature_index()); + clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1], + new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); + break; case HloOpcode::kRecv: case HloOpcode::kSend: - case HloOpcode::kUpdate: case HloOpcode::kIndex: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } + clone->set_metadata(metadata_); + return clone; } HloInstruction::~HloInstruction() {} @@ -1064,7 +1212,6 @@ std::unique_ptr HloInstruction::Clone( } } clone->set_parent(parent_); - clone->set_metadata(metadata_); return clone; } @@ -1131,6 +1278,29 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( return new_instruction; } +std::pair +HloInstruction::LatestNonGteAncestorAndIndex() const { + const HloInstruction* hlo = this; + ShapeIndex index; + while (hlo->opcode() == HloOpcode::kGetTupleElement) { + index.push_back(hlo->tuple_index()); + hlo = hlo->operand(0); + } + + // We built up index in the reverse order from what we want. + std::reverse(index.begin(), index.end()); + + return {hlo, index}; +} + +const HloInstruction* HloInstruction::LatestNonGteAncestor() const { + const HloInstruction* hlo = this; + while (hlo->opcode() == HloOpcode::kGetTupleElement) { + hlo = hlo->operand(0); + } + return hlo; +} + const Literal& HloInstruction::literal() const { CHECK_EQ(HloOpcode::kConstant, opcode_); return *literal_; @@ -1241,10 +1411,12 @@ bool HloInstruction::IdenticalSlowPath( // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: + case HloOpcode::kAtan2: case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kComplex: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: @@ -1255,12 +1427,13 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1268,8 +1441,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kPower: + case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSubtract: @@ -1375,7 +1552,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kUpdate: case HloOpcode::kSend: case HloOpcode::kRecv: return false; @@ -1689,16 +1865,20 @@ std::vector HloInstruction::ExtraAttributesToString() const { } if (opcode() == HloOpcode::kWhile) { - extra.push_back(StrCat("condition=", while_condition()->name())); - extra.push_back(StrCat("body=", while_body()->name())); + extra.push_back(StrCat("condition=%", while_condition()->name())); + extra.push_back(StrCat("body=%", while_body()->name())); } else if (opcode() == HloOpcode::kSelectAndScatter) { - extra.push_back(StrCat("select=", select()->name())); - extra.push_back(StrCat("scatter=", scatter()->name())); + extra.push_back(StrCat("select=%", select()->name())); + extra.push_back(StrCat("scatter=%", scatter()->name())); + } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || + opcode() == HloOpcode::kReduceWindow || + opcode() == HloOpcode::kReduce) { + extra.push_back(StrCat("to_apply=%", to_apply()->name())); } else if (!called_computations().empty()) { extra.push_back(StrCat( "calls=", Join(called_computations(), ", ", [](string* out, const HloComputation* computation) { - StrAppend(out, computation->name()); + StrAppend(out, "%", computation->name()); }))); } @@ -1709,6 +1889,9 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (opcode() == HloOpcode::kGetTupleElement) { extra.push_back(StrCat("index=", tuple_index())); } + if (device_assignment_.has_device()) { + extra.push_back(StrCat("device=", device_assignment_.device())); + } if (!control_successors_.empty()) { extra.push_back(StrCat( "control-successors=", @@ -1739,37 +1922,59 @@ HloInstructionProto HloInstruction::ToProto() const { for (const HloInstruction* control : control_predecessors_) { *proto.add_control_predecessor_names() = control->name(); } - for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); - } + *proto.mutable_metadata() = metadata_; - switch (opcode_) { - case HloOpcode::kConstant: - *proto.mutable_literal() = literal_->ToProto(); - break; - case HloOpcode::kParameter: - proto.set_parameter_number(parameter_number_); - proto.set_parameter_name(parameter_name_); - break; - case HloOpcode::kFusion: { - HloComputationProto* proto_fused_computation = - proto.mutable_fused_instructions_computation(); - proto_fused_computation->set_name(name()); - - // Fill in fused instructions in post order. - auto fused_instructions = - fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_instruction : fused_instructions) { - HloInstructionProto fused_proto = fused_instruction->ToProto(); - proto_fused_computation->add_instructions()->Swap(&fused_proto); - } - break; + if (literal_ != nullptr) { + *proto.mutable_literal() = literal_->ToProto(); + } + proto.set_parameter_number(parameter_number_); + proto.set_parameter_name(parameter_name_); + if (opcode() == HloOpcode::kFusion) { + proto.set_fusion_kind(xla::ToString(fusion_kind())); + *proto.mutable_fused_instructions_computation() = + fused_instructions_computation()->ToProto(); + } else { + for (const HloComputation* computation : called_computations_) { + *proto.add_called_computation_names() = computation->name(); } - case HloOpcode::kGetTupleElement: - proto.set_tuple_index(tuple_index_); - break; - default: {} // Nothing to do } + + proto.set_tuple_index(tuple_index_); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + if (window_ != nullptr) { + *proto.mutable_window() = *window_; + } + if (convolution_dimension_numbers_ != nullptr) { + *proto.mutable_convolution_dimension_numbers() = + *convolution_dimension_numbers_; + } + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + for (int64 slice_size : dynamic_slice_sizes_) { + proto.add_dynamic_slice_sizes(slice_size); + } + if (padding_config_ != nullptr) { + *proto.mutable_padding_config() = *padding_config_; + } + proto.set_outfeed_config(outfeed_config_); + if (opcode() == HloOpcode::kRng) { + proto.set_distribution(distribution_); + } + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + proto.set_channel_id(channel_id_); + proto.set_infeed_config(infeed_config_); + proto.set_custom_call_target(custom_call_target_); + *proto.mutable_outfeed_shape() = outfeed_shape_; + return proto; } @@ -1924,6 +2129,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { switch (opcode_) { case HloOpcode::kAbs: return visitor->HandleAbs(this, operands_[0]); + case HloOpcode::kAtan2: + return visitor->HandleAtan2(this, operands_[0], operands_[1]); case HloOpcode::kRoundNearestAfz: return visitor->HandleRound(this); case HloOpcode::kBatchNormTraining: @@ -1947,6 +2154,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kLt: case HloOpcode::kNe: return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]); + case HloOpcode::kComplex: + return visitor->HandleComplex(this, operands_[0], operands_[1]); case HloOpcode::kAdd: return visitor->HandleAdd(this, operands_[0], operands_[1]); case HloOpcode::kDivide: @@ -1957,10 +2166,17 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleMaximum(this); case HloOpcode::kMinimum: return visitor->HandleMinimum(this); - case HloOpcode::kLogicalAnd: - return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); - case HloOpcode::kLogicalOr: - return visitor->HandleLogicalOr(this, operands_[0], operands_[1]); + case HloOpcode::kAnd: + return visitor->HandleAnd(this, operands_[0], operands_[1]); + case HloOpcode::kOr: + return visitor->HandleOr(this, operands_[0], operands_[1]); + case HloOpcode::kShiftLeft: + return visitor->HandleShiftLeft(this, operands_[0], operands_[1]); + case HloOpcode::kShiftRightArithmetic: + return visitor->HandleShiftRightArithmetic(this, operands_[0], + operands_[1]); + case HloOpcode::kShiftRightLogical: + return visitor->HandleShiftRightLogical(this, operands_[0], operands_[1]); case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: @@ -2014,10 +2230,14 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleCos(this, operands_[0]); case HloOpcode::kSin: return visitor->HandleSin(this, operands_[0]); + case HloOpcode::kReal: + return visitor->HandleReal(this, operands_[0]); + case HloOpcode::kImag: + return visitor->HandleImag(this, operands_[0]); case HloOpcode::kIsFinite: return visitor->HandleIsFinite(this, operands_[0]); - case HloOpcode::kLogicalNot: - return visitor->HandleLogicalNot(this, operands_[0]); + case HloOpcode::kNot: + return visitor->HandleNot(this, operands_[0]); case HloOpcode::kBitcast: return visitor->HandleBitcast(this); case HloOpcode::kBroadcast: @@ -2063,7 +2283,6 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { // These opcodes are not handled here. case HloOpcode::kIndex: case HloOpcode::kTrace: - case HloOpcode::kUpdate: break; } return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s", @@ -2106,7 +2325,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, // // We need to keep track of both the id and the instruction because // instructions can get deleted while they are on the stack, so we - // can't always use the (potentiall dead) instruction object to grab + // can't always use the (potentially dead) instruction object to grab // its id. DFSStack dfs_stack; dfs_stack.emplace_back(root->unique_id(), root); @@ -2306,6 +2525,7 @@ bool HloInstruction::IsElementwiseBinary() const { // Binary elementwise operations. If you update this, please update // IsElementwise() accordingly. case HloOpcode::kAdd: + case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kGe: @@ -2319,8 +2539,11 @@ bool HloInstruction::IsElementwiseBinary() const { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; default: return false; @@ -2335,6 +2558,7 @@ bool HloInstruction::IsElementwise() const { // Unary elementwise operations. case HloOpcode::kAbs: + case HloOpcode::kAtan2: case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: @@ -2342,10 +2566,12 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kSign: case HloOpcode::kSin: @@ -2355,6 +2581,7 @@ bool HloInstruction::IsElementwise() const { // Binary elementwise operations, the same as in IsElementwiseBinary(). // If you update this, please update IsElementwiseBinary() accordingly. case HloOpcode::kAdd: + case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kGe: @@ -2368,8 +2595,11 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; // Ternary elementwise operations. @@ -2460,10 +2690,10 @@ class HloInstruction::FusionReusesParamElements { public: using UseKind = HloInstruction::UseKind; - // We could rather iterate backwards thru fused_instructions_ here, as it is - // in reverse postorder, and compute whether each fused instruction reuses - // the value of this parameter, which would save stack space but not allow - // us to finish early if we find a reuse. + // We could rather iterate backwards through fused_instructions_ here, as it + // is in reverse postorder, and compute whether each fused instruction reuses + // the value of this parameter, which would save stack space but not allow us + // to finish early if we find a reuse. static UseKind Compute(int64 i, const HloInstruction& hlo) { tensorflow::gtl::FlatMap memoization_cache; return ComputeInternal(i, hlo, &memoization_cache); @@ -2588,6 +2818,32 @@ string ToString(HloInstruction::FusionKind kind) { } } +StatusOr StringToFusionKind( + const string& kind_name) { + if (kind_name == "kLoop") { + return HloInstruction::FusionKind::kLoop; + } + if (kind_name == "kInput") { + return HloInstruction::FusionKind::kInput; + } + if (kind_name == "kOutput") { + return HloInstruction::FusionKind::kOutput; + } + if (kind_name == "kTransposeDot") { + return HloInstruction::FusionKind::kTransposeDot; + } + if (kind_name == "kConvBackwardFilter") { + return HloInstruction::FusionKind::kConvBackwardFilter; + } + if (kind_name == "kConvBackwardInput") { + return HloInstruction::FusionKind::kConvBackwardInput; + } + if (kind_name == "kCustom") { + return HloInstruction::FusionKind::kCustom; + } + return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2615,8 +2871,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); - lhs_dims[dnums.batch_dimension()] = 'b'; - lhs_dims[dnums.feature_dimension()] = 'f'; + lhs_dims[dnums.input_batch_dimension()] = 'b'; + lhs_dims[dnums.input_feature_dimension()] = 'f'; for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); } @@ -2628,12 +2884,19 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); } + std::vector output_dims(2 + dnums.spatial_dimensions().size()); + output_dims[dnums.output_batch_dimension()] = 'b'; + output_dims[dnums.output_feature_dimension()] = 'f'; + for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + output_dims[dnums.spatial_dimensions(i)] = StrCat(i); + } + result += "dim_labels="; append_dims(lhs_dims, operand(0)->shape()); result += "_"; append_dims(rhs_dims, operand(1)->shape()); result += "->"; - append_dims(lhs_dims, shape()); + append_dims(output_dims, shape()); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 73c4ebd9f11c722c699f3eb6e78703e837fe57db..d2a15b0f962317cc79ab93cf377a77939a5eba41 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -72,6 +72,23 @@ class HloInstruction { }; ~HloInstruction(); + + // Creates an instruction from the given proto. Arguments: + // + // module: the module which will contain the instruction. The newly created + // instruction is *not* added to the module or any computation, however. + // proto: the proto to convert from. + // instruction_map: a map from instruction name to HloInstruction*. This map + // must contain all operands of the newly constructed instruction. + // computation_map: a map from computation name to HloComputation*. This map + // must contain all computations which the newly constructed instruction + // calls. If the instruction is a fusion instruction, then the fusion + // computation is added to this map and the module. + static StatusOr> CreateFromProto( + HloModule* module, const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + tensorflow::gtl::FlatMap* computation_map); + // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, const Shape& shape, @@ -508,6 +525,26 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kGetTupleElement int64 tuple_index() const; + // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. + // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the + // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. + std::pair LatestNonGteAncestorAndIndex() + const; + + std::pair LatestNonGteAncestorAndIndex() { + auto rv = + const_cast(this)->LatestNonGteAncestorAndIndex(); + return {const_cast(rv.first), rv.second}; + } + + // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. + const HloInstruction* LatestNonGteAncestor() const; + + HloInstruction* LatestNonGteAncestor() { + return const_cast( + const_cast(this)->LatestNonGteAncestor()); + } + // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. // The setter should only be called by HloModule or HloComputation methods. // @@ -1055,7 +1092,7 @@ class HloInstruction { std::unique_ptr literal_; // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = 0; + int64 tuple_index_ = -1; // Dimensions present for some operations that require reshaping or // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. @@ -1073,8 +1110,8 @@ class HloInstruction { std::vector slice_strides_; // The bit sizes for a reduce-precision operation. - int32 exponent_bits_; - int32 mantissa_bits_; + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). @@ -1124,11 +1161,11 @@ class HloInstruction { // A small float number added to the variance to avoid divide-by-zero error. // Only present for kBatchNormTraining. - float epsilon_; + float epsilon_ = 0.0f; // An integer value representing the index of the feature dimension. // Only present for kBatchNormTraining. - int64 feature_index_; + int64 feature_index_ = -1; // Represents a unique identifier for each Send/Recv instruction pair. // Only present for kSend or kRecv. @@ -1154,6 +1191,8 @@ class HloInstruction { }; string ToString(HloInstruction::FusionKind kind); +StatusOr StringToFusionKind( + const string& kind_name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 3601d5cdbe66b305b4fa23fa5e1c519704befbb9..9affecae6072ddfaf81b75caaf70d2ba2c68bdaa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -706,6 +706,9 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { metadata, fusion->fused_expression_root()->metadata())); EXPECT_TRUE(protobuf_util::ProtobufEquals( metadata, fusion->fused_expression_root()->operand(0)->metadata())); + + auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {}); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); } TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { @@ -729,6 +732,23 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10)); } +TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { + HloComputation::Builder builder(TestName()); + auto* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({ + {1, 2}, + {3, 4}, + }))); + auto* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {0}) + ->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {1}) + ->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + auto tuple_clone = tuple->Clone(); + EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape())); +} + TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1183,13 +1203,13 @@ TEST_F(HloInstructionTest, Stringification) { EXPECT_EQ(fusion->ToString(false, false), "%fusion = f32[5,20]{1,0} fusion:kTransposeDot(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), calls=fused_computation"); + "f32[20,10]{1,0} %y), calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); EXPECT_EQ(loop->ToString(false, false), "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " - "condition=TransposeDot, body=TransposeDot"); + "condition=%TransposeDot, body=%TransposeDot"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index b1b3dd61a63d8f729912c5b533099f739f9aa9c4..5440ed2eda90e5cdb81c205a07ae7691a947b44b 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -79,9 +79,9 @@ HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); HLO_MATCHER(Log); -HLO_MATCHER(LogicalAnd); -HLO_MATCHER(LogicalNot); -HLO_MATCHER(LogicalOr); +HLO_MATCHER(And); +HLO_MATCHER(Not); +HLO_MATCHER(Or); HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); @@ -104,6 +104,9 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(ShiftLeft); +HLO_MATCHER(ShiftRightLogical); +HLO_MATCHER(ShiftRightArithmetic); HLO_MATCHER(Sign); HLO_MATCHER(Slice); HLO_MATCHER(Sort); @@ -112,7 +115,6 @@ HLO_MATCHER(Tanh); HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); -HLO_MATCHER(Update); HLO_MATCHER(While); #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 14590112a1edd16c0c1ab16d9e1d2aac5ce66e18..1758f2760c46a5f0f5876ac6ba8dd013e71455b6 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -45,10 +45,37 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(name), config_(config) {} HloComputation* HloModule::AddComputationInternal( - std::unique_ptr computation) { - computation->UniquifyName(&computation_name_uniquer_); + std::unique_ptr computation, bool is_entry, + bool uniquify_names) { + if (is_entry) { + CHECK_EQ(nullptr, entry_computation_); + entry_computation_ = computation.get(); + + // If the module configuration has no entry layout computation set, create a + // default one based on the program shape. + if (!config_.has_entry_computation_layout()) { + config_.SetDefaultComputationLayout( + entry_computation_->ComputeProgramShape()); + } + } + + if (uniquify_names) { + computation->UniquifyName(&computation_name_uniquer_); + for (auto* instruction : computation->instructions()) { + instruction->UniquifyName(&instruction_name_uniquer_); + } + } else { + // Don't uniquify the names of the computation or instruction, but we must + // run the names through the uniquifiers to prevent future name collisions + // for computations and instructions created later. + computation_name_uniquer_.GetUniqueName(computation->name()); + for (auto* instruction : computation->instructions()) { + instruction_name_uniquer_.GetUniqueName(instruction->name()); + } + } + + // Pick unique IDs for each instruction. for (auto* instruction : computation->instructions()) { - instruction->UniquifyName(&instruction_name_uniquer_); instruction->SetUniqueId(NewUniqueInstructionId()); } computation->set_parent(this); @@ -58,16 +85,8 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr computation) { - CHECK_EQ(nullptr, entry_computation_); - entry_computation_ = computation.get(); - - // If the module configuration has no entry layout computation set, create a - // default one based on the program shape. - if (!config_.has_entry_computation_layout()) { - config_.SetDefaultComputationLayout( - entry_computation_->ComputeProgramShape()); - } - return AddComputationInternal(std::move(computation)); + return AddComputationInternal(std::move(computation), /*is_entry=*/true, + /*uniquify_names=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -83,7 +102,8 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { - return AddComputationInternal(std::move(computation)); + return AddComputationInternal(std::move(computation), /*is_entry=*/false, + /*uniquify_names=*/true); } void HloModule::ReplaceComputations( @@ -153,11 +173,17 @@ void HloModule::ReplaceComputations( string HloModule::ToString() const { std::ostringstream s; s << "HloModule " << name() << ":\n\n"; - s << "ENTRY " << entry_computation()->ToString() << "\n\n"; - for (const std::unique_ptr& computation : computations_) { - if (computation.get() != entry_computation()) { - s << computation->ToString() << "\n\n"; + for (const HloComputation* computation : MakeComputationPostOrder()) { + // Fusion computations are emitted with their fusion instruction and + // therefore don't need to be emitted as a separate comptutation in the + // module. + if (computation->IsFusionComputation()) { + continue; + } + if (computation == entry_computation()) { + s << "ENTRY "; } + s << computation->ToString() << "\n\n"; } return s.str(); } @@ -167,12 +193,166 @@ HloModuleProto HloModule::ToProto() const { proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); for (const HloComputation* computation : MakeComputationPostOrder()) { + // Fusion computations are added when the fusion instructions are created by + // HloInstruction::CreateFromProto. + if (computation->IsFusionComputation()) { + continue; + } HloComputationProto computation_proto = computation->ToProto(); proto.add_computations()->Swap(&computation_proto); } return proto; } +namespace { + +// Construct a ProgramShape matching the shape of the parameters and root of the +// given module's entry computation. +StatusOr ProgramShapeFromProto(const HloModuleProto& module) { + const HloComputationProto* entry_computation = nullptr; + for (const HloComputationProto& computation : module.computations()) { + if (computation.name() == module.entry_computation_name()) { + entry_computation = &computation; + break; + } + } + TF_RET_CHECK(entry_computation != nullptr) + << "No computation with entry computation name" + << module.entry_computation_name(); + + tensorflow::gtl::FlatMap> parameters; + const HloInstructionProto* root = nullptr; + for (const HloInstructionProto& instruction : + entry_computation->instructions()) { + if (instruction.name() == entry_computation->root_name()) { + TF_RET_CHECK(root == nullptr) << "Entry computation has more than " + "one instruction with (root) name " + << instruction.name(); + root = &instruction; + } + if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number())) + << "Entry computation has more than one parameter instruction " + "with parameter number " + << instruction.parameter_number(); + parameters[instruction.parameter_number()] = { + instruction.parameter_name(), &instruction.shape()}; + } + } + TF_RET_CHECK(root != nullptr) + << "Entry computation is missing root instruction named " + << entry_computation->root_name(); + + ProgramShape program_shape; + *program_shape.mutable_result() = root->shape(); + for (int64 i = 0; i < parameters.size(); ++i) { + TF_RET_CHECK(ContainsKey(parameters, i)) + << "Entry computation missing parameter number " << i; + const string& name = parameters.at(i).first; + const Shape& shape = *parameters.at(i).second; + *program_shape.add_parameters() = shape; + program_shape.add_parameter_names(name); + } + + return std::move(program_shape); +} + +} // namespace + +/* static */ +StatusOr> HloModule::CreateFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config, + const VersionedComputationHandle& entry_computation_handle) { + // The ProgramShape in the passed in module config must match the shapes of + // the entry parameters and root. + TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape, + ProgramShapeFromProto(proto)); + TF_RET_CHECK(expected_program_shape.parameters_size() == + module_config.entry_computation_layout().parameter_count()); + for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { + const Shape& parameter_shape = + module_config.entry_computation_layout().parameter_layout(i).shape(); + TF_RET_CHECK( + ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape)) + << "HloModuleConfig has different shape for parameter " << i + << " than the HLO module. Expected: " + << ShapeUtil::HumanStringWithLayout( + expected_program_shape.parameters(i)) + << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); + } + const Shape& result_shape = + module_config.entry_computation_layout().result_layout().shape(); + TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape)) + << "HloModuleConfig has different result shape than the HLO module. " + "Expected: " + << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) + << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); + + auto module = MakeUnique(proto.name(), entry_computation_handle, + module_config); + + tensorflow::gtl::FlatMap computation_map; + for (const HloComputationProto& computation_proto : proto.computations()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, &computation_map)); + CHECK_NE(computation.get(), nullptr); + TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); + string computation_name = computation->name(); + // Don't uniquify names because we want names to be stable across + // serialization and deserialization. + computation_map[computation_name] = module->AddComputationInternal( + std::move(computation), + /*is_entry=*/proto.entry_computation_name() == computation_name, + /*uniquify_names=*/false); + } + TF_RET_CHECK(module->entry_computation_ != nullptr); + + // Because we didn't uniquify the names, double-check that the instruction and + // computation names are unique from the proto. + tensorflow::gtl::FlatSet computation_names; + tensorflow::gtl::FlatSet instruction_names; + for (HloComputation* computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + + TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) + << "Computation name is not unique: " << computation->name(); + computation_names.insert(computation->name()); + for (HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) + << "Instruction name is not unique: " << instruction->name(); + instruction_names.insert(instruction->name()); + } + } + + return std::move(module); +} + +/* static */ +StatusOr HloModule::CreateModuleConfigFromProto( + const HloModuleProto& module) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + ProgramShapeFromProto(module)); + + HloModuleConfig module_config(program_shape); + + // The module config is constructed with default layouts regardless of what is + // passed in via the ProgramShape. Set the layouts to the appropriate values. + ComputationLayout* entry_layout = + module_config.mutable_entry_computation_layout(); + for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { + TF_RETURN_IF_ERROR( + entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + program_shape.parameters(i))); + } + TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( + program_shape.result())); + + return module_config; +} + namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3546f4b3f7982d4bb325888da2438965aa946b7c..ad11d56006a79b509309daba55e94342911f76a1 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -140,7 +140,18 @@ class HloModule { const HloModuleConfig& config() const { return config_; } string ToString() const; + + // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; + static StatusOr> CreateFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config, + const VersionedComputationHandle& entry_computation_handle = + VersionedComputationHandle()); + + // Creates and returns an HloModuleConfig with an appropriate program shape + // for the HLO module in the given proto. + static StatusOr CreateModuleConfigFromProto( + const HloModuleProto& module); // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. @@ -176,7 +187,8 @@ class HloModule { private: HloComputation* AddComputationInternal( - std::unique_ptr computation); + std::unique_ptr computation, bool is_entry, + bool uniquify_names); const string name_; HloModuleConfig config_; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 83fe6ef6c967f865333eff51b04a33b1d11ffa7e..d94c4da5eab174493d40c54b572f482a602dee71 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -31,6 +33,10 @@ string HloOpcodeString(HloOpcode opcode) { return "abs"; case HloOpcode::kAdd: return "add"; + case HloOpcode::kAnd: + return "and"; + case HloOpcode::kAtan2: + return "atan2"; case HloOpcode::kBatchNormTraining: return "batch-norm-training"; case HloOpcode::kBatchNormInference: @@ -45,6 +51,8 @@ string HloOpcodeString(HloOpcode opcode) { return "call"; case HloOpcode::kClamp: return "clamp"; + case HloOpcode::kComplex: + return "complex"; case HloOpcode::kConcatenate: return "concatenate"; case HloOpcode::kConstant: @@ -85,6 +93,8 @@ string HloOpcodeString(HloOpcode opcode) { return "get-tuple-element"; case HloOpcode::kGt: return "greater-than"; + case HloOpcode::kImag: + return "imag"; case HloOpcode::kIndex: return "index"; case HloOpcode::kInfeed: @@ -95,12 +105,6 @@ string HloOpcodeString(HloOpcode opcode) { return "less-than-or-equal-to"; case HloOpcode::kLog: return "log"; - case HloOpcode::kLogicalAnd: - return "logical-and"; - case HloOpcode::kLogicalOr: - return "logical-or"; - case HloOpcode::kLogicalNot: - return "logical-not"; case HloOpcode::kLt: return "less-than"; case HloOpcode::kMap: @@ -115,6 +119,10 @@ string HloOpcodeString(HloOpcode opcode) { return "not-equal-to"; case HloOpcode::kNegate: return "negate"; + case HloOpcode::kNot: + return "not"; + case HloOpcode::kOr: + return "or"; case HloOpcode::kOutfeed: return "outfeed"; case HloOpcode::kPad: @@ -123,6 +131,8 @@ string HloOpcodeString(HloOpcode opcode) { return "parameter"; case HloOpcode::kPower: return "power"; + case HloOpcode::kReal: + return "real"; case HloOpcode::kRecv: return "recv"; case HloOpcode::kReduce: @@ -147,6 +157,12 @@ string HloOpcodeString(HloOpcode opcode) { return "select"; case HloOpcode::kSend: return "send"; + case HloOpcode::kShiftLeft: + return "shift-left"; + case HloOpcode::kShiftRightArithmetic: + return "shift-right-arithmetic"; + case HloOpcode::kShiftRightLogical: + return "shift-right-logical"; case HloOpcode::kSign: return "sign"; case HloOpcode::kSin: @@ -165,13 +181,93 @@ string HloOpcodeString(HloOpcode opcode) { return "transpose"; case HloOpcode::kTuple: return "tuple"; - case HloOpcode::kUpdate: - return "update"; case HloOpcode::kWhile: return "while"; } } +StatusOr StringToHloOpcode(const string& opcode_name) { + static auto* opcode_map = new tensorflow::gtl::FlatMap( + {{"abs", HloOpcode::kAbs}, + {"add", HloOpcode::kAdd}, + {"and", HloOpcode::kAnd}, + {"batch-norm-training", HloOpcode::kBatchNormTraining}, + {"batch-norm-inference", HloOpcode::kBatchNormInference}, + {"batch-norm-grad", HloOpcode::kBatchNormGrad}, + {"bitcast", HloOpcode::kBitcast}, + {"broadcast", HloOpcode::kBroadcast}, + {"call", HloOpcode::kCall}, + {"clamp", HloOpcode::kClamp}, + {"concatenate", HloOpcode::kConcatenate}, + {"constant", HloOpcode::kConstant}, + {"convert", HloOpcode::kConvert}, + {"convolution", HloOpcode::kConvolution}, + {"cosine", HloOpcode::kCos}, + {"cross-replica-sum", HloOpcode::kCrossReplicaSum}, + {"custom-call", HloOpcode::kCustomCall}, + {"copy", HloOpcode::kCopy}, + {"divide", HloOpcode::kDivide}, + {"dot", HloOpcode::kDot}, + {"dynamic-slice", HloOpcode::kDynamicSlice}, + {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice}, + {"equal-to", HloOpcode::kEq}, + {"exponential", HloOpcode::kExp}, + {"floor", HloOpcode::kFloor}, + {"ceil", HloOpcode::kCeil}, + {"fusion", HloOpcode::kFusion}, + {"greater-than-or-equal-to", HloOpcode::kGe}, + {"get-tuple-element", HloOpcode::kGetTupleElement}, + {"greater-than", HloOpcode::kGt}, + {"index", HloOpcode::kIndex}, + {"infeed", HloOpcode::kInfeed}, + {"is-finite", HloOpcode::kIsFinite}, + {"less-than-or-equal-to", HloOpcode::kLe}, + {"log", HloOpcode::kLog}, + {"less-than", HloOpcode::kLt}, + {"map", HloOpcode::kMap}, + {"maximum", HloOpcode::kMaximum}, + {"minimum", HloOpcode::kMinimum}, + {"multiply", HloOpcode::kMultiply}, + {"not", HloOpcode::kNot}, + {"not-equal-to", HloOpcode::kNe}, + {"negate", HloOpcode::kNegate}, + {"or", HloOpcode::kOr}, + {"outfeed", HloOpcode::kOutfeed}, + {"pad", HloOpcode::kPad}, + {"parameter", HloOpcode::kParameter}, + {"power", HloOpcode::kPower}, + {"recv", HloOpcode::kRecv}, + {"reduce", HloOpcode::kReduce}, + {"reduce-precision", HloOpcode::kReducePrecision}, + {"reduce-window", HloOpcode::kReduceWindow}, + {"remainder", HloOpcode::kRemainder}, + {"reshape", HloOpcode::kReshape}, + {"reverse", HloOpcode::kReverse}, + {"rng", HloOpcode::kRng}, + {"round-nearest-afz", HloOpcode::kRoundNearestAfz}, + {"select-and-scatter", HloOpcode::kSelectAndScatter}, + {"select", HloOpcode::kSelect}, + {"send", HloOpcode::kSend}, + {"shift-left", HloOpcode::kShiftLeft}, + {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic}, + {"shift-right-logical", HloOpcode::kShiftRightLogical}, + {"sign", HloOpcode::kSign}, + {"sine", HloOpcode::kSin}, + {"slice", HloOpcode::kSlice}, + {"sort", HloOpcode::kSort}, + {"subtract", HloOpcode::kSubtract}, + {"tanh", HloOpcode::kTanh}, + {"trace", HloOpcode::kTrace}, + {"transpose", HloOpcode::kTranspose}, + {"tuple", HloOpcode::kTuple}, + {"while", HloOpcode::kWhile}}); + auto it = opcode_map->find(opcode_name); + if (it == opcode_map->end()) { + return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + } + return it->second; +} + bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { case HloOpcode::kGe: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 7b23249640b0dcfdd510caf27bf57bb1f2f6850e..8090e4c82ea5cea1c80ac75426a6fa6cf3ad9e5f 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -30,6 +31,7 @@ namespace xla { enum class HloOpcode { kAbs, kAdd, + kAtan2, kBatchNormGrad, kBatchNormInference, kBatchNormTraining, @@ -38,6 +40,7 @@ enum class HloOpcode { kCall, kCeil, kClamp, + kComplex, kConcatenate, kConstant, kConvert, @@ -57,14 +60,15 @@ enum class HloOpcode { kGe, kGetTupleElement, kGt, + kImag, kIndex, kInfeed, kIsFinite, kLe, kLog, - kLogicalAnd, - kLogicalNot, - kLogicalOr, + kAnd, + kNot, + kOr, kLt, kMap, kMaximum, @@ -76,6 +80,7 @@ enum class HloOpcode { kPad, kParameter, kPower, + kReal, kRecv, kReduce, kReducePrecision, @@ -88,6 +93,9 @@ enum class HloOpcode { kSelect, kSelectAndScatter, kSend, + kShiftLeft, + kShiftRightArithmetic, + kShiftRightLogical, kSign, kSin, kSlice, @@ -97,13 +105,15 @@ enum class HloOpcode { kTrace, kTranspose, kTuple, - kUpdate, kWhile, }; // Returns a string representation of the opcode. string HloOpcodeString(HloOpcode opcode); +// Returns a string representation of the opcode. +StatusOr StringToHloOpcode(const string& opcode_name); + inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { return os << HloOpcodeString(opcode); } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index ed7b6c71bc6619b0cb93f226eb10de1023749109..53bd46a641afcba1b9551895955742e74a9f374b 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -59,6 +59,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { for (auto& invariant_checker : invariant_checkers_) { VLOG(1) << " Invariant checker " << invariant_checker->name(); StatusOr changed_status = invariant_checker->Run(module); + VLOG(1) << " Invariant checker done " << invariant_checker->name(); if (!changed_status.ok()) { VLOG(2) << "Module failed invariant check:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3f74e253f7a7882ec1c72e0ce634017dd2f0957 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -0,0 +1,178 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_runner.h" + +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloProtoFile(const char* filename, + const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), + filename, &proto)); + TF_ASSIGN_OR_RETURN( + HloModuleConfig config, + HloModule::CreateModuleConfigFromProto(proto.hlo_module())); + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(proto.hlo_module(), config)); + return std::move(module); +} + +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. +struct HloRunner::EigenThreadPoolWrapper { + std::unique_ptr pool; + std::unique_ptr device; +}; + +HloRunner::HloRunner() {} + +HloRunner::HloRunner(se::Platform* platform) { + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie(); + VLOG(1) << "Created HloRunner for platform: " << platform->Name(); +} + +HloRunner::~HloRunner() { + // Deallocate all the memory allocated during the tests. + for (auto& allocation : allocations_) { + backend().default_stream_executor()->Deallocate(&allocation); + } +} + +StatusOr HloRunner::Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments, + Shape* result_shape) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + backend().compiler()->Compile(std::move(module), + backend().default_stream_executor())); + + se::Stream stream(backend().default_stream_executor()); + stream.Init(); + + ExecutableRunOptions run_options; + run_options.set_stream(&stream); + run_options.set_allocator(backend().memory_allocator()); + run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); + run_options.set_intra_op_thread_pool( + backend().eigen_intra_op_thread_pool_device()); + + HloExecutionProfile hlo_execution_profile; + ServiceExecutableRunOptions service_run_options( + run_options, backend().StreamBorrower(), + backend().inter_op_thread_pool()); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase result, + executable->ExecuteOnStream(&service_run_options, arguments, + &hlo_execution_profile)); + TF_RET_CHECK(stream.BlockHostUntilDone()); + + allocations_.push_back(result); + + *result_shape = executable->result_shape(); + + if (ShapeUtil::IsTuple(*result_shape)) { + // We must record element buffers of tuples as well to avoid leaks. + DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + TF_ASSIGN_OR_RETURN( + std::vector element_buffers, + backend().transfer_manager()->ShallowCopyTupleFromDevice( + backend().default_stream_executor(), result, *result_shape)); + + // A tuple may contain the same buffer in more than one element. Keep track + // of the buffers already added to avoid duplicates in allocations_. + std::set added_opaques; + for (auto element_buffer : element_buffers) { + if (added_opaques.count(element_buffer.opaque()) == 0) { + CHECK(element_buffer.opaque() != nullptr); + added_opaques.insert(element_buffer.opaque()); + allocations_.push_back(element_buffer); + } + } + } + + return result; +} + +StatusOr HloRunner::TransferToDevice( + const Literal& literal) { + // Allocate memory on the device using the stream executor. + int64 allocation_size = + backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); + se::DeviceMemoryBase allocation = + backend().default_stream_executor()->AllocateArray( + allocation_size); + allocations_.push_back(allocation); + + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, &allocation)); + + return allocation; +} + +StatusOr> HloRunner::TransferFromDevice( + const Shape& shape, se::DeviceMemoryBase device_base) { + auto literal = MakeUnique(); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), device_base, shape, shape, + literal.get())); + return std::move(literal); +} + +StatusOr> HloRunner::ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + Shape result_shape; + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base, + Execute(std::move(module), arguments, &result_shape)); + return TransferFromDevice(result_shape, device_base); +} + +Backend& HloRunner::backend() { + if (!backend_) { + backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); + VLOG(1) << "executing on platform " << backend().platform()->Name(); + } + return *backend_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..a4d7b653dbfbfdb169c07bca3e461147fd9d077a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// A base class for running an HloModule. This executes the given HloModule on a +// certain backend directly without using the client interface. HloModule can be +// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +class HloRunner { + public: + HloRunner(); + + HloRunner(::perftools::gputools::Platform* platform); + + ~HloRunner(); + + // Reads the binary proto file in xla.HloProto format, creates and returns the + // HloModule. + static StatusOr> ReadModuleFromHloProtoFile( + const char* filename, const DebugOptions& debug_options); + + // Executes the given module with given literals as input and returns the + // result as a Literal. The LiteralPtr type accepts Literal* or + // std::unique_ptr. + template + StatusOr> Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice literals); + + // Executes the given module and returns a global data handle. + StatusOr Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments, + Shape* result_shape); + + // Transfers the given literal to the device and returns the data handle. + StatusOr TransferToDevice( + const Literal& literal); + + // Transfers the array referred to by the given handle from the device and + // returns as a Literal. + StatusOr> TransferFromDevice( + const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); + + // Executes the given module and return the result as a Literal. + StatusOr> ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments); + + // If backend is not created in the constructor, creates and returns the + // default backend. If creation fails, crashes the program. + // + // This creates the backend lazily so it's possible to instantiate an + // HloRunner in a program without any backends linked in. + Backend& backend(); + + private: + struct EigenThreadPoolWrapper; + + std::vector allocations_; + + std::unique_ptr thread_pool_wrapper_; + + std::unique_ptr backend_; +}; + +template +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice literals) { + std::vector arguments; + for (const auto& literal : literals) { + TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, + TransferToDevice(*literal)); + arguments.push_back(argument); + } + return ExecuteAndTransfer(std::move(module), arguments); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 3f6d89f24f4ec76d913611d03dd28b93a09d34a1..2007a8f11d823044da01b5f4ec7a354c56776b12 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -56,6 +56,8 @@ TensorShapeProto GetTensorShape(const HloInstruction* instruction) { return tensor_shape; } +string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } + } // namespace void CleanNodeName(string* name) { @@ -178,6 +180,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, case HloOpcode::kCustomCall: attrs["custom_call_target"].set_s(instruction->custom_call_target()); break; + case HloOpcode::kSend: + case HloOpcode::kRecv: + attrs["channel_id"].set_i(instruction->channel_id()); + break; default: break; } @@ -192,6 +198,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { NodeDef* node_def = graph_def_.add_node(); node_def->set_name(GetNodeNameForInstruction(instruction)); node_def->set_op(GetOpDefName(instruction)); + if (instruction->device_assignment().has_device()) { + node_def->set_device( + GetDeviceName(instruction->device_assignment().device())); + } SetNodeAttrs(instruction, node_def); if (instruction->opcode() == HloOpcode::kFusion) { for (auto* fused_instruction : instruction->fused_instructions()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 35dff4a957f023d0f34082d7db1b6a6ade9c15f8..f3a098057b9f454b2ae3366cc39fdfa6a57c83f6 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -64,6 +64,10 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleConvert(HloInstruction* convert) override { + if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) { + TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape())) + << "Unsupported complex->real kConvert"; + } return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc index 0682434bfbac42ac3839c7066f167b7505dfdd0a..40df0dc355c62236c3915e344db515b45f899403 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/inliner.cc @@ -76,8 +76,7 @@ Status InlinerVisitor::HandleMap( // Only inlining functions that are simply a single operation until a better // profitability model for inlining is defined. if (hlo_query::AllOperandsAreParameters(root)) { - if (root.opcode() == HloOpcode::kUpdate || - root.opcode() == HloOpcode::kFusion || + if (root.opcode() == HloOpcode::kFusion || root.opcode() == HloOpcode::kIndex || root.opcode() == HloOpcode::kParameter || root.opcode() == HloOpcode::kTrace) { @@ -90,8 +89,12 @@ Status InlinerVisitor::HandleMap( // different than the map shape. Hence, a broadcast is needed, else the // cloned operand with new shape and operands work. if (root.opcode() != HloOpcode::kConstant) { + std::vector params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(operands[root.operand(o)->parameter_number()]); + } HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), operands)); + root.CloneWithNewOperands(map->shape(), params)); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); } else { diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 9d845c5545680f1a5389fa89af02eb723203a5f7..7aa1c7c8358318d02a000d968a2672123400ad6e 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -108,5 +108,44 @@ TEST_F(InlinerTest, MapConstant) { LiteralTestUtil::ExpectEqual(*result, *expected); } +TEST_F(InlinerTest, MapSubtractOppositeOrder) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + // Note that the parameter ordinals are in the opposite order to their + // position as operands + auto max_builder = HloComputation::Builder(TestName()); + auto param1 = max_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "x")); + auto param2 = max_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "y")); + max_builder.AddInstruction(HloInstruction::CreateBinary( + param1->shape(), HloOpcode::kSubtract, param1, param2)); + auto max_f32 = max_builder.Build(); + + auto builder = HloComputation::Builder("MapSubFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewModule(); + hlo_module->AddEmbeddedComputation(std::move(max_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + Inliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), + op::Subtract(rhs, lhs)); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto expected = Literal::CreateR1({3, 1, -1, -3}); + LiteralTestUtil::ExpectEqual(*result, *expected); +} + + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 7a273816427f536c71da61a28d7e60d0544f9bf2..fae3ca8ad22cb2d1d7bce03664c4e83c014c6de3 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -32,17 +32,16 @@ namespace xla { const HloInstruction& instruction) { switch (instruction.opcode()) { // Cheap instructions. - case HloOpcode::kAbs: case HloOpcode::kAdd: case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConstant: case HloOpcode::kConvert: case HloOpcode::kCopy: - case HloOpcode::kCos: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: @@ -50,12 +49,13 @@ namespace xla { case HloOpcode::kGe: case HloOpcode::kGetTupleElement: case HloOpcode::kGt: + case HloOpcode::kImag: case HloOpcode::kInfeed: case HloOpcode::kIsFinite: case HloOpcode::kLe: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -64,20 +64,30 @@ namespace xla { case HloOpcode::kNegate: case HloOpcode::kOutfeed: case HloOpcode::kPad: + case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: - case HloOpcode::kSign: - case HloOpcode::kSin: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; + // Cheap instructions for reals, but expensive for complex. + case HloOpcode::kAbs: + case HloOpcode::kCos: + case HloOpcode::kSign: + case HloOpcode::kSin: + return ShapeUtil::ElementIsComplex(instruction.shape()); + // Expensive instructions. + case HloOpcode::kAtan2: case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: @@ -102,7 +112,6 @@ namespace xla { case HloOpcode::kSort: case HloOpcode::kTanh: case HloOpcode::kTrace: - case HloOpcode::kUpdate: case HloOpcode::kWhile: case HloOpcode::kSend: case HloOpcode::kRecv: diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index c8d02834f43a747980d084be37602bc56db74b98..93ea2f736742eab06ee0d7e881ee7c51daee9878 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -88,7 +88,7 @@ StatusOr> InterpreterCompiler::Compile( StatusOr>> InterpreterCompiler::Compile( std::vector> /*hlo_modules*/, - std::vector /*stream_execs*/) { + std::vector> /*stream_execs*/) { return tensorflow::errors::Unimplemented( "Compilation of multiple HLO modules is not supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index 13db38ab60a07bdf476227c9b9e818dfe2cdcc6c..cfdc9b6256569b0137784b0d1db846a5f2339a5d 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -49,7 +49,8 @@ class InterpreterCompiler : public Compiler { StatusOr>> Compile( std::vector> hlo_modules, - std::vector stream_exec) override; + std::vector> + stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> hlo_modules, diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 8fd330fda715276894c5e4e5c9fc7f8c4416105d..7eda7c2284c2457703fcfcd4226172e41dd4ae01 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -732,7 +732,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // dimension bound is 1 in the operand shape, there may be several such // layouts. So if 'output_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy - // operations. + // operations. For similar reasons, if the operand and output have the same + // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && ShapeUtil::Rank(instruction->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. @@ -748,6 +749,13 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { return MakeUnique(operand_shape.layout()); } + if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { + *operand_shape.mutable_layout() = output_layout; + if (ShapeUtil::ReshapeIsBitcast(operand_shape, + output_shape_with_layout)) { + return MakeUnique(output_layout); + } + } auto aligned_operand_shape = ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); if (aligned_operand_shape) { @@ -796,7 +804,8 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // dimension bound is 1 in the user shape, there may be several such // layouts. So if 'operand_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy - // operations. + // operations. For similar reasons, if the operand and output have the same + // rank, try to match the outputs's layout to the operand. if (ShapeUtil::Rank(operand->shape()) == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. @@ -812,6 +821,13 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { return MakeUnique(output_shape.layout()); } + if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { + *output_shape.mutable_layout() = operand_layout; + if (ShapeUtil::ReshapeIsBitcast(output_shape, + operand_shape_with_layout)) { + return MakeUnique(operand_layout); + } + } auto aligned_user_shape = ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); if (aligned_user_shape) { @@ -1180,8 +1196,6 @@ Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, // to match the layout of its corresponding fusion instruction operand. Also, // set the layout of the fused root to match the layout of the fusion // instruction itself. -// Fused GetTupleElement requires a layout so that TBAA metadata for the tuple -// element array pointer load can be added. Status SetFusionLayouts(HloInstruction* fusion) { TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion); for (auto* fused_instruction : fusion->fused_instructions()) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index f498f950573fbb1b3594cc7ab9a57fe979fa4c60..075d4a1ab5e5f39394ade393d21525ca3e97136e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -93,7 +93,6 @@ cc_library( deps = [ ":ir_array", ":llvm_loop", - ":ops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -112,7 +111,7 @@ cc_library( ":ir_array", ":llvm_util", ":loop_emitter", - ":ops", + ":tuple_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -128,6 +127,23 @@ cc_library( name = "ops", srcs = ["ops.cc"], hdrs = ["ops.h"], + deps = [ + ":fused_ir_emitter", + ":ir_array", + ":llvm_util", + ":loop_emitter", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/gpu:partition_assignment", + ], +) + +cc_library( + name = "tuple_ops", + srcs = ["tuple_ops.cc"], + hdrs = ["tuple_ops.h"], deps = [ ":ir_array", ":llvm_util", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 5e28e37600c18a351e8647d48119f073277f56e1..bdddc232ef74dfa37e2d5cc780b0fe11e7bc8e76 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -92,7 +92,16 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, llvm::MDNode* AliasAnalysis::GetAliasDomain() { llvm::MDBuilder metadata_builder(*context_); if (alias_domain_ == nullptr) { - alias_domain_ = metadata_builder.createAnonymousAliasScopeDomain(); + // We use createAliasScopeDomain rather than createAnonymousAliasScopeDomain + // so that when functions get inlined, we continue using the one domain, + // rather than duplicating it (and thus having two AA domains in one + // function). + // + // A side-effect of this is that if you ever compile two HLO modules in the + // same LLVM module, they'll have the same alias scope domain. This isn't a + // problem because the two HLO modules will never interact with one another. + alias_domain_ = + metadata_builder.createAliasScopeDomain("XLA global AA domain"); } return alias_domain_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 7d1fad753e0d94f7d88b824ed57d52890a48b1dd..a2af2580ffb58df08b1a6a4b26a06f03978f2228 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -75,7 +75,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { Status FusedIrEmitter::HandleConstant(HloInstruction* constant, const Literal& literal) { llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, ir_builder_); + llvm_ir::ConvertLiteralToIrConstant(literal, module_); llvm::GlobalVariable* global = new llvm::GlobalVariable( *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, @@ -101,7 +101,7 @@ Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( get_tuple_element->shape(), get_tuple_element->tuple_index(), - /*alignment=*/1, it->second, ir_builder_); + /*alignment=*/1, it->second, ir_builder_, module_); gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr)); // Emit code to read base tuple element array (if non-tuple shaped). if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { @@ -134,7 +134,7 @@ Status FusedIrEmitter::HandleTuple( std::vector operand_elemental_ir_types; for (HloInstruction* operand : operands) { operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( - operand->shape().element_type(), ir_builder_)); + operand->shape().element_type(), module_)); } generators_[tuple] = [=](const IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index a24e104067f19e45ab2566beedbb8217913bad12..a44da5137857533b9015853b62844279e58104ff 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -42,7 +42,8 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), elemental_emitter_(elemental_emitter), - ir_builder_(elemental_emitter->ir_builder()) {} + ir_builder_(elemental_emitter->ir_builder()), + module_(elemental_emitter->module()) {} Status DefaultAction(HloInstruction* hlo) override; @@ -85,6 +86,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Borrowed llvm::IRBuilder<>* ir_builder_; + llvm::Module* module_; // Map from instruction pointers to functions to generate elements of their // outputs diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e36c791c1a52f4e0699cc15ef913fbd2bdcca557..e3f98ac13e76f0df465066422ca7918a0f218b60 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -229,9 +229,11 @@ llvm::Value* IrArray::EmitArrayElementAddress( } if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + llvm::Module* module = + ir_builder->GetInsertBlock()->getParent()->getParent(); return ir_builder->CreateInBoundsGEP( ir_builder->CreateBitCast( - base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), ir_builder) + base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module) ->getPointerTo()), {index.linear()}, llvm_ir::AsStringRef(name)); } @@ -268,8 +270,6 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder, name); llvm::LoadInst* load = ir_builder->CreateLoad(element_address); - llvm_ir::SetTbaaForInstruction(load, GetShape(), - /*is_pointer_to=*/false); AnnotateLoadStoreInstructionWithMetadata(load); return load; } @@ -278,14 +278,13 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, llvm::IRBuilder<>* ir_builder) const { llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder); llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); - llvm_ir::SetTbaaForInstruction(store, GetShape(), - /*is_pointer_to=*/false); AnnotateLoadStoreInstructionWithMetadata(store); } IrArray IrArray::CastToShape(const Shape& new_shape, llvm::IRBuilder<>* ir_builder) const { - llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, ir_builder); + llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); + llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module); return IrArray( ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 4a7d2b48f786119067b3ad992410c985ccf80829..5dff4b5778970dd473c5f158b3828a850847d1ff 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" @@ -38,6 +39,19 @@ limitations under the License. namespace xla { namespace llvm_ir { +namespace { + +// Note, this function is only useful in an insertion context; in a global +// (e.g. constants) context it will CHECK fail. +llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) { + auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock()); + auto fn = CHECK_NOTNULL(block->getParent()); + auto module = CHECK_NOTNULL(fn->getParent()); + return module; +} + +} // namespace + string AsString(const std::string& str) { return string(str.data(), str.length()); } @@ -63,7 +77,7 @@ llvm::Value* EmitCallToIntrinsic( for (auto type : overloaded_types) { types.push_back(type); } - llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); + llvm::Module* module = ModuleFromIRBuilder(ir_builder); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); std::vector operands_vec; @@ -119,38 +133,53 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, } llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, - llvm::IRBuilder<>* ir_builder) { + llvm::Module* module) { switch (element_type) { case PRED: case S8: case U8: - return ir_builder->getInt8Ty(); + return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: - return ir_builder->getInt16Ty(); + return llvm::Type::getInt16Ty(module->getContext()); case S32: case U32: - return ir_builder->getInt32Ty(); + return llvm::Type::getInt32Ty(module->getContext()); case S64: case U64: - return ir_builder->getInt64Ty(); + return llvm::Type::getInt64Ty(module->getContext()); case F32: - return ir_builder->getFloatTy(); + return llvm::Type::getFloatTy(module->getContext()); case F64: - return ir_builder->getDoubleTy(); + return llvm::Type::getDoubleTy(module->getContext()); + case C64: { + auto cplx_t = module->getTypeByName("complex64"); + if (cplx_t == nullptr) { + // C++ standard dictates the memory layout of std::complex is contiguous + // real followed by imaginary. C++11 section 26.4 [complex.numbers]: + // If z is an lvalue expression of type cv std::complex then the + // expression reinterpret_cast(z) shall be well-formed, + // reinterpret_cast(z)[0] shall designate the real part of + // z, and reinterpret_cast(z)[1] shall designate the + // imaginary part of z. + return llvm::StructType::create( + "complex64", llvm::Type::getFloatTy(module->getContext()), + llvm::Type::getFloatTy(module->getContext())); + } + return cplx_t; + } // A Tuple contains an array of pointers. Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. case OPAQUE: - return ir_builder->getInt8PtrTy(); + return llvm::Type::getInt8PtrTy(module->getContext()); default: LOG(FATAL) << "unsupported type " << element_type; } } -llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) { - llvm::Type* result_type = - PrimitiveTypeToIrType(shape.element_type(), ir_builder); +llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { + llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); if (ShapeUtil::IsTuple(shape)) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); @@ -197,10 +226,10 @@ namespace { // value down to zero). llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, std::vector* multi_index, - llvm::IRBuilder<>* ir_builder) { + llvm::Module* module) { const Shape& shape = literal.shape(); llvm::Type* ir_element_type = - llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder); + llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); if (dimension_index == -1) { // Base case of the recursion. Index into the data field of the protobuf // with the multi index. @@ -238,6 +267,16 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); break; + case C64: { + complex64 x = literal.Get(*multi_index); + value = llvm::ConstantStruct::get( + static_cast(ir_element_type), + llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), + x.real()), + llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), + x.imag())); + break; + } default: LOG(FATAL) << "unsupported type " << shape.element_type(); } @@ -256,8 +295,8 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, std::vector elements; for (int64 i = 0; i < shape.dimensions(dimension); ++i) { (*multi_index)[dimension] = i; - elements.push_back(LiteralToConstant(literal, dimension_index - 1, - multi_index, ir_builder)); + elements.push_back( + LiteralToConstant(literal, dimension_index - 1, multi_index, module)); } llvm::Type* element_type; @@ -279,11 +318,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, - llvm::IRBuilder<>* ir_builder) { + llvm::Module* module) { std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); llvm::Constant* value = LiteralToConstant( literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, ir_builder); + &multi_index, module); return value; } @@ -380,7 +419,8 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1 // arrays. So we extend it to i8 so that it's addressable. return ir_builder->CreateZExt( - comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder)); + comparison_result, + llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder))); } // Internal helper that is called from emitted code to log an int64 value with a @@ -402,13 +442,6 @@ void EmitLogging(const char* tag, llvm::Value* value, {ir_builder->getInt64(tensorflow::bit_cast(tag)), value}); } -void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, - bool is_pointer_to) { - // TODO(b/62903316): TBAA metadata causes LLVM to miscompile generated code, - // most likely because the generated metadata is incorrect. Disable TBAA - // metadata while we resolve this. -} - void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { llvm::LLVMContext& context = load->getContext(); llvm::Type* int64_ty = llvm::Type::getInt64Ty(context); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 5af62b056ef19cb3c06e86dcb3a84e21fb2b7fee..304192b58e9331c2544f973bf65299111122aea8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -127,11 +127,11 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, // Returns the LLVM type which represents the given XLA primitive type. llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, - llvm::IRBuilder<>* ir_builder); + llvm::Module* module); // Returns the LLVM type which represents the given XLA shape. For example, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. -llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder); +llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); // Returns a value that represents a pointer to a global string constant that // encodes the shape as a serialized protobuf. @@ -149,7 +149,7 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, // Converts a given literal to an IR Constant. Literals have known constant // values at IR emission time. llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, - llvm::IRBuilder<>* ir_builder); + llvm::Module* module); // Inserts an allocate of the requested type at the entry point of the // function that the builder is currently building. The insert point @@ -227,12 +227,6 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* ir_builder); -// Adds TBAA metadata to a load or store instruction using the given shape as -// it's type. The is_pointer_to parameter is used to indicate whether or not -// this instruction loads or stores a pointer to an array. -void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, - bool is_pointer_to); - // Adds alignment metadata to a load instruction using the given alignment. // The alignment refers to the result of the load, not the load itself. void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment); diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 8bba1776d19005292da48705df2436b6f30e0f2d..6fa4cd08c9e0ac30b83c0e2b49d98d930c2e15df 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index ac562e231c8f56184363d6e186c18847d67435ce..34899b7400464e4f4f97d301f35ed3b7b083bca1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -14,86 +14,167 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" - -#include -#include -#include - -#include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) { - CHECK(ShapeUtil::IsScalar(pred.GetShape())); - - llvm::LoadInst* pred_value = - ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder->CreateICmpNE( - pred_value, - llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0), - "boolean_predicate"); - - VLOG(2) << "HandleSelect for tuple:"; - VLOG(2) << " pred_value: " << DumpToString(*pred_value); - VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); - - for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; - llvm::Value* on_true_element_address = - ir_builder->CreateInBoundsGEP(on_true, element_index); - llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); - llvm::Value* on_false_element_address = - ir_builder->CreateInBoundsGEP(on_false, element_index); - llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); - - llvm::Value* output_element_address = - ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); - ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), - output_element_address); - } +bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, + const BufferAssignment& assignment) { + CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); + const HloInstruction* operand = dynamic_update_slice->operand(0); + return assignment.HasTopLevelAllocation(dynamic_update_slice) && + assignment.HasTopLevelAllocation(operand) && + assignment.SharesTopLevelSlice(dynamic_update_slice, operand); } -void EmitTuple(IrArray tuple, - tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder) { - for (size_t i = 0; i < operands.size(); ++i) { - ir_builder->CreateStore( - ir_builder->CreatePointerCast(operands[i], - PrimitiveTypeToIrType(TUPLE, ir_builder)), - ir_builder->CreateInBoundsGEP( - tuple.GetBasePointer(), - {ir_builder->getInt64(0), ir_builder->getInt64(i)})); +// Shared implementation of EmitDynamicUpdateSliceInPlace and +// EmitFusedDynamicUpdateSliceInPlace. +// +// Emits a sequential loop if launch_dimensions is null. +static Status EmitDynamicUpdateSliceInPlaceImpl( + const Shape& update_shape, const ElementGenerator& start_indices_generator, + ElementGenerator update_array_generator, const IrArray& output_array, + const gpu::LaunchDimensions* launch_dimensions, + tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) { + const Shape& output_shape = output_array.GetShape(); + + // Read start indices from start_indices_generator. + const int64 rank = ShapeUtil::Rank(output_shape); + IrArray::Index start_index(rank); + for (int64 i = 0; i < rank; ++i) { + IrArray::Index dim_index({ir_builder->getInt64(i)}); + TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); } + + auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { + // Calculate output_index, where we'll write the value from update. For + // each dimension, + // + // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // + IrArray::Index output_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm::Value* dim_size = llvm::ConstantInt::get( + update_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + start_index[i], update_index[i]->getType()); + output_index[i] = ir_builder->CreateURem( + ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + } + + // Do output[output_index] = update[update_index]. + TF_ASSIGN_OR_RETURN(llvm::Value * update_data, + update_array_generator(update_index)); + output_array.EmitWriteArrayElement(output_index, update_data, ir_builder); + return Status::OK(); + }; + + if (launch_dimensions != nullptr) { + return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape, + *launch_dimensions, ir_builder) + .EmitLoop(name); + } + return LoopEmitter(loop_body_emitter, update_shape, ir_builder) + .EmitLoop(name); +} + +Status EmitDynamicUpdateSliceInPlace( + tensorflow::gtl::ArraySlice operand_arrays, + const IrArray& output_array, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder) { + VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; + + // No need to use operand_arrays[0], the input array of the + // dynamic-update-slice, because we know it aliases the op's output. + IrArray update_array = operand_arrays[1]; + IrArray start_indices_array = operand_arrays[2]; + Shape output_shape = output_array.GetShape(); + Shape update_shape = update_array.GetShape(); + + ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { + return start_indices_array.EmitReadArrayElement(index, ir_builder); + }; + ElementGenerator update_array_generator = [&](const IrArray::Index& index) { + return update_array.EmitReadArrayElement(index, ir_builder); + }; + + return EmitDynamicUpdateSliceInPlaceImpl( + update_shape, start_indices_generator, update_array_generator, + output_array, /*launch_dimensions=*/nullptr, name, ir_builder); +} + +// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and +// EmitParallelFusedDynamicUpdateSliceInPlace. +// +// Emits a sequential loop if launch_dimensions is null. +static Status EmitFusedDynamicUpdateSliceInPlaceImpl( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions* launch_dimensions, + llvm::IRBuilder<>* ir_builder) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " + << fusion->ToShortString(); + + auto* dynamic_update_slice = fusion->fused_expression_root(); + + const auto* update = dynamic_update_slice->operand(1); + const auto* start_indices = dynamic_update_slice->operand(2); + Shape update_shape = update->shape(); + + // Our in-place dynamic-update-slice implementation emits a loop over + // update_shape. To emit a cache-friendly loop, we need to know that shape's + // layout. + // + // update_shape is inside a fusion node -- it's never materialized in memory + // and thus doesn't have a layout. In this case we use the layout of the + // fusion node for iteration, since that corresponds to the order in memory of + // the buffer we'll be writing to. + // + // (This isn't necessarily optimal; in some cases it might be faster to peek + // through the chain of ops that gives us the update operand and use the + // layout of its source buffer(s). But this is no worse than we do with + // fusion elsewhere.) + TF_RETURN_IF_ERROR( + LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); + + // Create element generators for update and start_indices. + FusedIrEmitter fused_emitter(fusion_operand_arrays, elemental_emitter); + TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); + ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); + ElementGenerator start_indices_generator = + fused_emitter.GetGenerator(start_indices); + + return EmitDynamicUpdateSliceInPlaceImpl( + update_shape, start_indices_generator, update_array_generator, + fusion_output_array, launch_dimensions, IrName(fusion), ir_builder); +} + +Status EmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + llvm::IRBuilder<>* ir_builder) { + return EmitFusedDynamicUpdateSliceInPlaceImpl( + fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + /*launch_dimensions=*/nullptr, ir_builder); } -llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, - int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder) { - llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( - operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); - llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); - SetTbaaForInstruction(src_buffer, target_shape, /*is_pointer_to=*/true); - SetAlignmentMetadataForLoad(src_buffer, alignment); - llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder); - llvm::Value* ret_val = - ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); - return ret_val; +Status EmitParallelFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* ir_builder) { + return EmitFusedDynamicUpdateSliceInPlaceImpl( + fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + &launch_dimensions, ir_builder); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index 4e1d9d1080b3a5c8d8a09145f68bcff9d329c929..11e84d9cb5defbcb87a8f696d56c139686c960d8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -13,67 +13,68 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/types.h" + +// Utilities related to emitting LLVM IR for various HLO ops. namespace xla { namespace llvm_ir { -// Selection among tuples is special in how it's lowered, because a tuple is not -// an HLO array. -// -// tuple_on_true tuple_on_false -// | | -// V V -// ------------------------ ------------------------ -// | address of element 0 | | address of element 0 | -// |----------------------| |----------------------| -// | address of element 1 | | address of element 1 | -// |----------------------| |----------------------| -// | address of element 2 | | address of element 2 | -// ------------------------ ------------------------ -// \ / -// \ / -// ---------- -// pred ---------> | select | -// ---------- -// | -// V -// output ----> ------------------------ -// | address of element 0 | -// |----------------------| -// | address of element 1 | -// |----------------------| -// | address of element 2 | -// ------------------------ +// Checks if we can emit code for the given DynamicUpdateSlice node that updates +// its input in place. Returns true if the dynamic-update-slice's +// array-to-be-updated and output share the same BufferAllocation::Slice. // -// Only the addresses are copied to the output. For each element, we emit a copy -// of the address from the corresponding element in either -// tuple_on_true or tuple_on_false: -// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder); +// dynamic_update_slice must be a DynamicUpdateSlice op. +bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, + const BufferAssignment& assignment); + +// Checks if the given fusion node is amenable to being implemented by +// EmitFusedDynamicUpdateSliceInPlace. +inline bool CanEmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, const BufferAssignment& assignment) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && + fusion->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice && + CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(), + assignment); +} + +// Emits IR for running the given dynamic-update-slice op in-place -- that is, +// where the input and output buffers share the same slice, so we can simply +// modify the input/output buffer without touching any of the other elements. +Status EmitDynamicUpdateSliceInPlace( + tensorflow::gtl::ArraySlice operand_arrays, + const IrArray& output_array, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder); + +// Given a loop-fusion node whose root is a dynamic-update-slice op whose +// array-to-be-updated and output share the same buffer slice, emits +// (sequential) code for a fusion node that does the dynamic-update-slice in +// place. +Status EmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + llvm::IRBuilder<>* ir_builder); -// A tuple is an array of pointers, one for each operand. Each pointer points to -// the output buffer of its corresponding operand. -void EmitTuple(IrArray tuple, - tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder); +// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with +// the given launch dimensions. +Status EmitParallelFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* ir_builder); -// A tuple is an array of pointers, one for each operand. Each pointer points to -// the output buffer of its corresponding operand. A GetTupleElement instruction -// forwards the pointer to underlying tuple element buffer at the given index. -// Returns an llvm value representing a pointer to the tuple element buffer. -llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, - int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder); } // namespace llvm_ir } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a21eda35757aa706565ee4a5286eee1acea117b --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" + +#include +#include +#include + +#include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace llvm_ir { + +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, + llvm::Module* module) { + CHECK(ShapeUtil::IsScalar(pred.GetShape())); + + llvm::LoadInst* pred_value = + ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder->CreateICmpNE( + pred_value, + llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0), + "boolean_predicate"); + + VLOG(2) << "HandleSelect for tuple:"; + VLOG(2) << " pred_value: " << DumpToString(*pred_value); + VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); + + for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { + std::vector element_index = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; + llvm::Value* on_true_element_address = + ir_builder->CreateInBoundsGEP(on_true, element_index); + llvm::Value* on_true_element = ir_builder->CreateLoad( + on_true_element_address, + tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + llvm::Value* on_false_element_address = + ir_builder->CreateInBoundsGEP(on_false, element_index); + llvm::Value* on_false_element = ir_builder->CreateLoad( + on_false_element_address, + tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + + llvm::Value* output_element_address = + ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); + ir_builder->CreateStore( + ir_builder->CreateSelect( + pred_cond, on_true_element, on_false_element, + tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + output_element_address); + } +} + +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + for (size_t i = 0; i < operands.size(); ++i) { + auto* store = ir_builder->CreateStore( + ir_builder->CreatePointerCast(operands[i], + PrimitiveTypeToIrType(TUPLE, module)), + ir_builder->CreateInBoundsGEP( + tuple.GetBasePointer(), + {ir_builder->getInt64(0), ir_builder->getInt64(i)})); + tuple.AnnotateLoadStoreInstructionWithMetadata(store); + } +} + +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder, + llvm::Module* module) { + llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( + operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); + llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); + + // Mark the loaded pointer as dereferenceable if we know its shape. + if (!ShapeUtil::IsOpaque(target_shape)) { + SetDereferenceableMetadataForLoad( + src_buffer, + ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); + } + SetAlignmentMetadataForLoad(src_buffer, alignment); + + llvm::Type* element_type = ShapeToIrType(target_shape, module); + llvm::Value* ret_val = + ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); + return ret_val; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..dbf9a140068b60505f6798360438f709bfd3feba --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +// Utilities for emitting LLVM IR related to HLO tuples. + +namespace xla { +namespace llvm_ir { + +// Selection among tuples is special in how it's lowered, because a tuple is not +// an HLO array. +// +// tuple_on_true tuple_on_false +// | | +// V V +// ------------------------ ------------------------ +// | address of element 0 | | address of element 0 | +// |----------------------| |----------------------| +// | address of element 1 | | address of element 1 | +// |----------------------| |----------------------| +// | address of element 2 | | address of element 2 | +// ------------------------ ------------------------ +// \ / +// \ / +// ---------- +// pred ---------> | select | +// ---------- +// | +// V +// output ----> ------------------------ +// | address of element 0 | +// |----------------------| +// | address of element 1 | +// |----------------------| +// | address of element 2 | +// ------------------------ +// +// Only the addresses are copied to the output. For each element, we emit a copy +// of the address from the corresponding element in either +// tuple_on_true or tuple_on_false: +// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, + llvm::Module* module); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder, llvm::Module* module); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. A GetTupleElement instruction +// forwards the pointer to underlying tuple element buffer at the given index. +// Returns an llvm value representing a pointer to the tuple element buffer. +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder, + llvm::Module* module); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 069f85af721228c8f5d40cf243eea7f1e5173c62..a0d08c288dbcc45e83a36ce7b094b04a9dbae532 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -23,7 +23,24 @@ namespace xla { string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); - int* count = &(generated_names_[root]); + + // Strip away numeric suffix (if any). Only recognize separator if it is in + // the middle of the name. + size_t separator_index = root.rfind(separator_); + if (separator_index != string::npos && (separator_index > 0) && + (separator_index < root.size() - 1)) { + string after_suffix = root.substr(separator_index + 1); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + // Remove numeric suffix from root. + root = root.substr(0, separator_index); + // Update count to at least the numeric suffix value to avoid future + // colisions with this name. + generated_names_[root] = std::max(generated_names_[root], numeric_suffix); + } + } + + int64* count = &(generated_names_[root]); if (*count == 0) { *count = 1; return root; @@ -31,9 +48,6 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { tensorflow::strings::StrAppend(&root, separator_, *count); // Increment lookup under old 'root' name. (*count)++; - // Initialize count under new 'root' name. - count = &(generated_names_[root]); - *count = 1; return root; } } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index b0944adbc1d98fd88c550cc8b53cf399e43535e6..ed379b52258463b960dea788721c2c4325ef0260 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -43,7 +43,7 @@ class NameUniquer { // Map from name prefix to the number of names generated using that prefix // so far. - std::unordered_map generated_names_; + std::unordered_map generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9f0747a6e2175a968d8f3661ac51512009e86f29 --- /dev/null +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/name_uniquer.h" + +#include +#include +#include + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class NameUniquerTest : public ::testing::Test {}; + +TEST_F(NameUniquerTest, SimpleUniquer) { + NameUniquer uniquer; + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo__1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo__2", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar")); + EXPECT_EQ("foo__3", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar__1", uniquer.GetUniqueName("bar")); + EXPECT_EQ("qux", uniquer.GetUniqueName("qux")); +} + +TEST_F(NameUniquerTest, DifferentSeparator) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.2", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar")); + EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar")); +} + +TEST_F(NameUniquerTest, NumericSuffixes) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); + EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); + + // Separator is only recognized in the middle of the prefix. + EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); + EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 4f915a0c2eeaca0fe077a907571c8379992185eb..3a1818de82d3fd305e2c6b3bd1f2cf8125806a75 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -84,15 +84,6 @@ PlatformUtil::GetSupportedPlatforms() { return NotFound("no platforms found"); } else if (platforms.size() == 1) { return platforms[0]; - } else if (platforms.size() == 2) { - // In the service we always link the cpu backend for ComputeConstant. So if - // one of the two platforms is CPU then pick the other (non-cpu) platform as - // the default. - if (platforms[0]->id() == se::host::kHostPlatformId) { - return platforms[1]; - } else if (platforms[1]->id() == se::host::kHostPlatformId) { - return platforms[0]; - } } // Multiple platforms present and we can't pick a reasonable default. diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index fe0281a69a441b5462470e88bd3ad73784a8da35..eac573703085aca2801885cd9abbe0022f1c029e 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -36,12 +36,7 @@ class PlatformUtil { // Convenience function which returns the default supported platform. If // exactly one supported platform is present, then this platform is the - // default platform. If exactly two supported platforms are present and one - // platform is CPU (host) then the non-CPU platform is default. This logic is - // used because the XLA service always links in the CPU backend to run - // ComputeConstant, so if exactly one other platform is linked in, we assume - // the intent is to execute on that non-CPU platform. If none of these - // conditions are met the function returns an error. + // default platform. Otherwise returns an error. static StatusOr GetDefaultPlatform(); // Returns a vector of StreamExecutors for the given platform. The vector is diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 404fd3e6d7faaedd6cfba2996034f72d8f116d4e..0fb90230f2f39a841973361f63d17af579a1342b 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -48,23 +48,28 @@ namespace xla { namespace { -// Checks if an instruction can change its shape simply by adjusting metadata. -// This is the case if it is: -// -// - an instruction does not have any producers like Constants -// or Rng instruction, or is a scalar. -// -// Or -// -// - an reshape/transpose instruction with an operand that can trivially change -// its shape. -bool InstructionCanTriviallyChangeShape(const HloInstruction* instruction) { - // Reshape/Transposes are only trivial if their operand is trivial. - if (instruction->opcode() == HloOpcode::kReshape || - instruction->opcode() == HloOpcode::kTranspose) { - CHECK_EQ(instruction->operand_count(), 1); - return InstructionCanTriviallyChangeShape(instruction->operand(0)); - } +bool IsReshapeOrTranspose(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kReshape || + instruction->opcode() == HloOpcode::kTranspose; +} + +// Returns true iff `instruction` can change its shape simply by adjusting +// metadata. +bool CanTriviallyChangeShape(const HloInstruction* instruction) { + // NOTE: Technically a sequence of reshape(reshape(constant)) is also + // trivially reshapable, so we might be tempted to simply recurse if + // IsReshapeOrTranspose(instruction)==true. + // + // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially + // reshapable if *all* instructions in the chain have user_count == 1. And + // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar; we + // rely on implicit scalar broadcast for scalars to be trivial. In addition, + // these cases make it harder to maintain correctness of the UpdateOperand + // logic below. + // + // So don't handle these chains, unless you update the tests and code to deal + // with these properly. One idea is to add a pass immediately beforehand that + // collapses trivial runs of reshapes / transposes. // Scalars can operate with any shape. if (ShapeUtil::IsScalar(instruction->shape())) { @@ -93,9 +98,8 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( const HloInstruction* hlo) { for (HloInstruction* operand : hlo->operands()) { if (!ShapeUtil::IsScalar(operand->shape()) && - ((operand->opcode() == HloOpcode::kReshape || - operand->opcode() == HloOpcode::kTranspose) && - !InstructionCanTriviallyChangeShape(operand->operand(0)))) { + IsReshapeOrTranspose(operand) && + !CanTriviallyChangeShape(operand->operand(0))) { VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " << hlo->ToStringNoMetadata() << ":\n\t" << operand->ToStringNoMetadata(); @@ -122,28 +126,15 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } -// Returns true if an elementwise operation has all operands that can easily -// change shape. Operands can easily change shape if they are all -// reshapes/transposes to and from the same shape. Additionally, operands like -// constant, rng, and any scalar change shape with only an adjustment of -// metadata. -bool IsElementwiseOfEquivalentReshapesOrTransposes( - const HloInstruction* instruction) { - const auto& operands = instruction->operands(); - HloInstruction* first_reshape_operand = - FirstNonScalarAndNonTrivialReshapeOperand(instruction); - // If there are no non-trivial reshapes or transposes, then there is nothing - // to sink below the elementwise operation. - if (!first_reshape_operand) { - return false; - } - VLOG(3) << "** Checking whether instruction is an elementwise operation of " - "equivalent reshapes/transposes: " +// Returns true if all operands of `instruction` can easily change shape. +// Operands can easily change shape if they are all reshapes/transposes to and +// from the same shape. Additionally, operands like constant, rng, and any +// scalar change shape with only an adjustment of metadata. +bool AllOperandsHaveEasyShapeChanges( + const HloInstruction* instruction, + const HloInstruction* first_reshape_operand) { + VLOG(3) << "** Checking whether all operands have easy shape changes: " << instruction->ToStringNoMetadata(); - bool result = (instruction->user_count() > 0 || - instruction == instruction->parent()->root_instruction()) && - instruction->IsElementwise() && !operands.empty(); - // Check whether all operands: // 0. Have the same dimensions as the output -- if not, it may be // implicitly broadcast, which can confound the movement's @@ -155,66 +146,117 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes( // or // 2. Are one of kConstant, kRng, and scalars that can change shape // trivially, - if (result) { - for (auto& operand : operands) { - if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToStringNoMetadata() - << "\n\tinstruction: " << instruction->ToStringNoMetadata(); - result = false; - break; - } - - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToStringNoMetadata() - << "\n\toperand: " << operand->ToStringNoMetadata(); - continue; - } + for (const HloInstruction* operand : instruction->operands()) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToStringNoMetadata() + << "\n\tinstruction: " << instruction->ToStringNoMetadata(); + return false; + } - if (InstructionCanTriviallyChangeShape(operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToStringNoMetadata(); - continue; - } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\toperand: " << operand->ToStringNoMetadata(); + continue; + } - // TODO(someone): Look into supporting general ops for the operands as - // well. - VLOG(5) << "Operand is neither equalivant to the first Reshape operand" - "nor can trivially change shape: " + if (CanTriviallyChangeShape(operand)) { + VLOG(5) << "Operand can trivially change shape: " << operand->ToStringNoMetadata(); - result = false; - break; + continue; } + + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is neither equalivant to the first Reshape operand" + "nor can trivially change shape: " + << operand->ToStringNoMetadata(); + return false; } - VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for " - << instruction->ToStringNoMetadata() << ": " << result; - return result; + VLOG(3) << "All operands have easy shape changes: " + << instruction->ToStringNoMetadata(); + return true; +} + +// This function is called once we've decided to sink reshape/transpose operands +// across an instruction. It returns an updated `operand` with a shape that +// plays nicely with `new_operand_shape`; either it has the same shape (of the +// correct type), or it is a scalar that may be implicitly broadcast. +HloInstruction* UpdateOperand(HloComputation* computation, + const HloInstruction* first_reshape_operand, + const Shape& new_operand_shape, + HloInstruction* operand) { + const PrimitiveType element_type = operand->shape().element_type(); + const Shape new_shape = + ShapeUtil::ChangeElementType(new_operand_shape, element_type); + + switch (operand->opcode()) { + case HloOpcode::kConstant: { + if (first_reshape_operand->opcode() == HloOpcode::kReshape) { + VLOG(5) << "Adding reshape to kConstant operand"; + return computation->AddInstruction( + HloInstruction::CreateReshape(new_shape, operand)); + } else { + CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); + VLOG(5) << "Adding transpose to kConstant operand"; + std::vector inverse_permutation = + InversePermutation(first_reshape_operand->dimensions()); + return computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); + } + } + case HloOpcode::kRng: { + CHECK_EQ(operand->user_count(), 1); + VLOG(5) << "Cloning kRng operand with new shape"; + return computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + } + case HloOpcode::kReshape: + case HloOpcode::kTranspose: { + VLOG(5) << "Using existing operand of kReshape or kTranspose"; + return operand->mutable_operand(0); + } + default: + LOG(FATAL) << "Unexpected operand opcode during update: " << operand; + } } // Try to sink any reshape or transpose operands of `instruction` across it. We // do so if `instruction` is elementwise and all operands are either equivalent -// reshapes/transposes or are trivially reshapable. Note that no move is -// performend if there is no nontrivial reshapes/transposes. +// reshapes/transposes or are trivially reshapable. StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, HloInstruction* instruction) { - if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { + // Only perform sinks for live elementwise instructions with operands. + const bool is_dead = instruction->user_count() == 0 && + instruction != computation->root_instruction(); + if (!instruction->IsElementwise() || instruction->operands().empty() || + is_dead) { return false; } - HloInstruction* old_reshape = + // Only perform sinks if there are any nontrivial reshape/transpose operands. + const HloInstruction* first_reshape_operand = FirstNonScalarAndNonTrivialReshapeOperand(instruction); - TF_RET_CHECK(old_reshape != nullptr); - Shape new_elementwise_shape = old_reshape->operand(0)->shape(); + if (!first_reshape_operand) { + return false; + } + + // Only perform sinks if all operands can easily change shape. + if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { + return false; + } - VLOG(3) << "** Trying to sink reshape or transpose: " - << instruction->ToStringNoMetadata() - << "\n\told reshape: " << old_reshape->ToStringNoMetadata() - << "\n\tnew elementwise shape: " - << ShapeUtil::HumanString(new_elementwise_shape); + // At this point we've decided to sink reshape/transpose operands. + const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); + VLOG(3) << "** Sinking reshape or transpose: " + << instruction->ToStringNoMetadata() << "\n\tfirst reshape operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\tnew operand shape: " + << ShapeUtil::HumanString(new_operand_shape); auto operands = instruction->operands(); for (size_t i = 0; i < operands.size(); ++i) { @@ -224,55 +266,19 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, if (ShapeUtil::IsScalar(operands[i]->shape())) { continue; } - PrimitiveType element_type = operands[i]->shape().element_type(); - switch (operands[i]->opcode()) { - case HloOpcode::kConstant: { - if (old_reshape->opcode() == HloOpcode::kReshape) { - VLOG(3) << "Creating reshape for kConstant operand " << i << ": " - << operands[i]->ToStringNoMetadata(); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i])); - } else { - TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose); - std::vector inverse_permutation = - InversePermutation(old_reshape->dimensions()); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateTranspose( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i], inverse_permutation)); - } - break; - } - case HloOpcode::kRng: { - CHECK_EQ(operands[i]->user_count(), 1); - operands[i] = instruction->parent()->AddInstruction( - operands[i]->CloneWithNewOperands( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i]->operands())); - break; - } - case HloOpcode::kReshape: - case HloOpcode::kTranspose: - operands[i] = operands[i]->mutable_operand(0); - break; - default: - LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " - "transposes."; - } + VLOG(3) << "Updating operand #" << i << ": " + << operands[i]->ToStringNoMetadata(); + operands[i] = UpdateOperand(computation, first_reshape_operand, + new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not be equivalent - // reshapes, so all the fused instructions have the same dimensions. + // implicit broadcast as if it were the operands would not have easy shape + // changes, so all the fused instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); - *shape->mutable_layout() = new_elementwise_shape.layout(); + *shape->mutable_dimensions() = new_operand_shape.dimensions(); + *shape->mutable_layout() = new_operand_shape.layout(); } } HloInstruction* new_elementwise = @@ -284,12 +290,12 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, // // In this case, convert' should have the same element type as // `convert` and the same dimensions as operands[0]. - ShapeUtil::ChangeElementType(new_elementwise_shape, + ShapeUtil::ChangeElementType(new_operand_shape, instruction->shape().element_type()), operands)); std::unique_ptr new_reshape; - switch (old_reshape->opcode()) { + switch (first_reshape_operand->opcode()) { case HloOpcode::kReshape: VLOG(3) << "Creating new reshape for new elementwise op: " << new_elementwise->ToStringNoMetadata(); @@ -297,8 +303,9 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, HloInstruction::CreateReshape(instruction->shape(), new_elementwise); break; case HloOpcode::kTranspose: - new_reshape = HloInstruction::CreateTranspose( - instruction->shape(), new_elementwise, old_reshape->dimensions()); + new_reshape = + HloInstruction::CreateTranspose(instruction->shape(), new_elementwise, + first_reshape_operand->dimensions()); break; default: LOG(FATAL) << "Bad opcode"; @@ -312,6 +319,8 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, StatusOr ReshapeMover::Run(HloModule* module) { bool changed = false; + VLOG(2) << "Pre ReshapeMover HLO:"; + XLA_VLOG_LINES(2, module->ToString()); for (auto* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool did_change, @@ -319,6 +328,8 @@ StatusOr ReshapeMover::Run(HloModule* module) { changed |= did_change; } } + VLOG(2) << "Post ReshapeMover HLO:"; + XLA_VLOG_LINES(2, module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index b7e0a46939a10b3376758109214c9722976f50e0..1f59e3b3147facb6f2fae00d6c810bf54d560e5c 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -26,7 +26,7 @@ namespace xla { // them inputward also. class ReshapeMover : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "reshape-motion"; } + tensorflow::StringPiece name() const override { return "reshape-mover"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index a81d3f4eb344510b3973aa46a576d11f613bc404..aac8638a54f744f0c230ec6c5ca071c1daf45ab2 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloTestBase; +using ReshapeMoverTest = HloVerifiedTestBase; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); @@ -50,13 +50,12 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); @@ -89,13 +88,12 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); @@ -115,13 +113,12 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -142,12 +139,11 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, param1))); @@ -193,21 +189,19 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2)); builder.AddInstruction(HloInstruction::CreateTernary( - ShapeUtil::MakeShape(PRED, {2, 3}), HloOpcode::kSelect, const0, reshape1, - reshape2)); + root_shape, HloOpcode::kSelect, const0, reshape1, reshape2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(const0, reshape1, reshape2)); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(op::Reshape(const0), param1, param2))); - EXPECT_EQ(const0->shape().DebugString(), + EXPECT_EQ(root_shape.DebugString(), computation->root_instruction()->shape().DebugString()); } @@ -228,17 +222,16 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, root_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); @@ -260,7 +253,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // trivial reshapes. TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); - auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); + auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = @@ -272,18 +265,17 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); auto pred = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(PRED, {1, 3, 1, 2}), "pred")); + 0, ShapeUtil::MakeShape(PRED, {3, 2}), "pred")); builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); @@ -323,13 +315,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), const1)); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, op::Reshape(const1)))); @@ -337,6 +328,48 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { computation->root_instruction()->shape().DebugString()); } +// For a graph that looks like: +// +// +- reshape0 - param0 (shape A) +// | +// +- reshape1 - const1 (shape B) +// | +// add +// +// There is 1 non-trivial reshape (reshape0). It's not clear whether reshape1 +// should be trivial or not; conceptually it's trivial, but handling it would +// complicate the rest of our logic. +// +// For now we treat it as non-trivial, so we verify that we don't sink the +// reshapes in this case. +TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 3}), "param0")); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({9, 8, 7}))); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); + + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, reshape1)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(const1))); + + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(const1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); @@ -351,15 +384,14 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, param1))); @@ -386,14 +418,13 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(pred, param0, param1))); @@ -416,12 +447,11 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); @@ -468,12 +498,11 @@ TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kMultiply, constant, reshape)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Constant(), op::Reshape(param0))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Constant(), op::Reshape(param0))); @@ -517,15 +546,14 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, reshape2, reshape3)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Add(op::Reshape(param2), op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index bd7898a41fa9e36630f32e89cc56665b42c2fbc0..0fbc2f2fec64917f5117dc5021c5e0a5b0f4367e 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -187,8 +187,9 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg, *result->mutable_computation() = computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p", - result->computation().ShortDebugString().c_str(), this); + VLOG(1) << Printf("Created new computation %s on service %p, name %s", + result->computation().ShortDebugString().c_str(), this, + arg->name().c_str()); return tensorflow::Status::OK(); } @@ -337,7 +338,7 @@ StatusOr>> Service::BuildExecutables( std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector executors) { + std::vector> executors) { VLOG(1) << Printf("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. @@ -614,31 +615,41 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); std::vector> all_arguments; - std::vector executors; + std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; std::vector computation_names; std::vector device_handles; - if (arg->requests_size() * options_.number_of_replicas() > + int num_requested_devices = + std::accumulate(arg->requests().begin(), arg->requests().end(), 0, + [](int a, const ExecuteRequest& r) -> int { + return a + r.execution_options().device_handles_size(); + }); + if (num_requested_devices * options_.number_of_replicas() > execute_backend_->device_count()) { return FailedPrecondition( "there are not enough stream executors to execute %d computations", - arg->requests_size()); + num_requested_devices); } for (int64 i = 0; i < arg->requests_size(); ++i) { // Get the stream executor for the i'th computation. This stream executor // is one of the executors to run the replicated computation. - if (!arg->requests(i).has_device_handle()) { + const ExecutionOptions& execution_options = + arg->requests(i).execution_options(); + if (execution_options.device_handles().empty()) { return FailedPrecondition( "device handles must be given to execute parallel computations"); } - TF_ASSIGN_OR_RETURN( - auto replicas, - Replicas(*execute_backend_, arg->requests(i).device_handle())); - se::StreamExecutor* executor = replicas[0]; - CHECK(executor != nullptr); + std::vector executors; + for (const auto& device_handle : execution_options.device_handles()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, device_handle)); + se::StreamExecutor* executor = replicas[0]; + CHECK(executor != nullptr); + executors.push_back(executor); + } // Resolve the UserComputation object associated with the requested // computation and compute the program shape. @@ -657,10 +668,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. + // In the case of partitioned computations, assume all arguments go on the + // zeroth core. TF_ASSIGN_OR_RETURN( std::vector arg_allocations, ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), - executor->device_ordinal())); + executors[0]->device_ordinal())); std::vector arguments; arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { @@ -677,11 +690,15 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {}); versioned_handles.push_back(versioned_handle); module_configs.push_back(std::move(module_config)); - computation_names.push_back(user_computation->name()); - executors.push_back(executor); - device_handles.push_back(arg->requests(i).device_handle()); + computation_names.insert(computation_names.end(), executors.size(), + user_computation->name()); + all_executors.push_back(executors); + device_handles.insert(device_handles.end(), + execution_options.device_handles().begin(), + execution_options.device_handles().end()); } // Build the user computations into HloModules and compile to generate the @@ -689,7 +706,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::vector> executables, BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), executors)); + execute_backend_.get(), all_executors)); std::vector executable_ptrs; executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { @@ -751,6 +768,17 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return InvalidArgument("computations may not be empty"); } + // If we received multiple device handles, we must partition the module. + if (arg->execution_options().device_handles_size() > 1) { + ExecuteParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index f96f18f072ad728b4612381a03549e9fd3110574..2452259f736054b5bf1f03fc5103d65eded7f398 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -293,7 +293,7 @@ class Service : public ServiceInterface { std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector executors); + std::vector> executors); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ffd80188274e17b4345c1d9d4f47ea04195798b2..0458932a730d2c82f256f01b5693e443245e7901 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -53,14 +53,18 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_EXP; case HloOpcode::kFloor: return UNOP_FLOOR; + case HloOpcode::kImag: + return UNOP_IMAG; case HloOpcode::kIsFinite: return UNOP_IS_FINITE; case HloOpcode::kLog: return UNOP_LOG; - case HloOpcode::kLogicalNot: - return UNOP_LOGICAL_NOT; + case HloOpcode::kNot: + return UNOP_NOT; case HloOpcode::kNegate: return UNOP_NEGATE; + case HloOpcode::kReal: + return UNOP_REAL; case HloOpcode::kRoundNearestAfz: return UNOP_ROUND_NEAREST_AFZ; case HloOpcode::kSign: @@ -81,6 +85,10 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { // opcode. BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { switch (opcode) { + case HloOpcode::kAtan2: + return BINOP_ATAN2; + case HloOpcode::kComplex: + return BINOP_COMPLEX; case HloOpcode::kDot: return BINOP_DOT; case HloOpcode::kMultiply: @@ -113,10 +121,16 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_POW; case HloOpcode::kRemainder: return BINOP_REM; - case HloOpcode::kLogicalOr: - return BINOP_LOGICAL_OR; - case HloOpcode::kLogicalAnd: - return BINOP_LOGICAL_AND; + case HloOpcode::kOr: + return BINOP_OR; + case HloOpcode::kAnd: + return BINOP_AND; + case HloOpcode::kShiftLeft: + return BINOP_SHIFT_LEFT; + case HloOpcode::kShiftRightArithmetic: + return BINOP_SHIFT_RIGHT_ARITHMETIC; + case HloOpcode::kShiftRightLogical: + return BINOP_SHIFT_RIGHT_LOGICAL; default: LOG(FATAL) << "unhandled opcode " << opcode; } @@ -130,8 +144,6 @@ TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { return TRIOP_CLAMP; case HloOpcode::kSelect: return TRIOP_SELECT; - case HloOpcode::kUpdate: - return TRIOP_UPDATE; default: LOG(FATAL) << "unhandled opcode " << opcode; } @@ -303,30 +315,53 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (operation) { case UNOP_FLOOR: case UNOP_CEIL: + if (!ShapeUtil::ElementIsFloating(arg)) { + return InvalidArgument( + "expected element type in shape to be floating for floor/ceil " + "operation; got %s", + PrimitiveType_Name(arg.element_type()).c_str()); + } + return arg; case UNOP_COS: case UNOP_SIN: case UNOP_EXP: case UNOP_LOG: case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg)) { + if (!ShapeUtil::ElementIsFloating(arg) && + !ShapeUtil::ElementIsComplex(arg)) { return InvalidArgument( - "expected element type in shape to be floating for exp/log/tanh " - "operation; got %s", + "expected element type in shape to be floating or complex for " + "sin/cos/exp/log/tanh operation; got %s", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; + case UNOP_REAL: + case UNOP_IMAG: + if (!ShapeUtil::ElementIsComplex(arg)) { + return InvalidArgument( + "expected element type in shape to be complex for real/imag " + "operation; got %s", + PrimitiveType_Name(arg.element_type()).c_str()); + } + return ShapeUtil::ChangeElementType(arg, F32); case UNOP_ABS: + if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType( + arg, primitive_util::ComplexComponentType(arg.element_type())); + } + return arg; case UNOP_NEGATE: case UNOP_ROUND_NEAREST_AFZ: case UNOP_SIGN: case UNOP_SORT: return arg; - case UNOP_LOGICAL_NOT: - if (arg.element_type() != PRED) { + case UNOP_NOT: + if (arg.element_type() != PRED && + !primitive_util::IsIntegralType(arg.element_type())) { return InvalidArgument( - "expected pred element type in argument to logical-not operation; " - "got %s", + "expected pred or an integral element type in argument to not " + "operation; got %s", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -457,7 +492,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { return InvalidArgument( - "the rank of the operand and the padding configuration do not match."); + "The rank of the operand and the padding configuration do not match: " + "%s vs %s", + ShapeUtil::HumanString(operand_shape).c_str(), + padding_config.ShortDebugString().c_str()); } if (operand_shape.element_type() != padding_value_shape.element_type()) { return InvalidArgument( @@ -743,24 +781,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_MIN: case BINOP_SUB: case BINOP_ADD: + case BINOP_ATAN2: case BINOP_POW: case BINOP_DIV: case BINOP_REM: case BINOP_MUL: + case BINOP_SHIFT_LEFT: + case BINOP_SHIFT_RIGHT_ARITHMETIC: + case BINOP_SHIFT_RIGHT_LOGICAL: return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); - case BINOP_LOGICAL_AND: - case BINOP_LOGICAL_OR: - if (lhs.element_type() != PRED) { + case BINOP_COMPLEX: { + if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( - "expected pred element type in argument to logical and/or " + "expected element type in shape to be floating for complex compose " "operation; got %s", PrimitiveType_Name(lhs.element_type()).c_str()); } + TF_ASSIGN_OR_RETURN(const Shape& shape, + InferElementwiseBinaryOpShape(operation, lhs, rhs, + broadcast_dimensions)); + if (lhs.element_type() == F32) { + return ShapeUtil::ChangeElementType(shape, C64); + } else { + return Unimplemented("complex component type not supported"); + } + } + case BINOP_AND: + case BINOP_OR: + if (lhs.element_type() != PRED && + !primitive_util::IsIntegralType(lhs.element_type())) { + return InvalidArgument( + "expected pred or integral type in argument to and/or operation; " + "got %s", + PrimitiveType_Name(lhs.element_type()).c_str()); + } return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: case BINOP_GE: case BINOP_GT: @@ -809,14 +867,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InferClampShape(lhs, rhs, ehs); case TRIOP_SELECT: return InferSelectShape(lhs, rhs, ehs); - case TRIOP_UPDATE: - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation")); - return lhs; default: return InvalidArgument("unknown operation %s", TernaryOperation_Name(operation).c_str()); @@ -1372,14 +1422,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "Window: %s", window.DebugString().c_str()); } - int num_spatial_dims = dnums.spatial_dimensions_size(); - if (num_spatial_dims < 1) { - return InvalidArgument( - "Convolution requires at least one spatial dimension.\n" - "Window: %s", - window.DebugString().c_str()); - } + const int num_spatial_dims = dnums.spatial_dimensions_size(); if (window.dimensions_size() != num_spatial_dims) { return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" @@ -1387,7 +1431,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( window.DebugString().c_str(), dnums.DebugString().c_str()); } - int num_dims = num_spatial_dims + 2; + const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d.\n" @@ -1406,8 +1450,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Verifies that the input and window dimensions are a permutation of // the dimension numbers. std::vector input_dnums(num_dims); - input_dnums[0] = dnums.batch_dimension(); - input_dnums[1] = dnums.feature_dimension(); + input_dnums[0] = dnums.input_batch_dimension(); + input_dnums[1] = dnums.input_feature_dimension(); std::copy(dnums.spatial_dimensions().begin(), dnums.spatial_dimensions().end(), input_dnums.begin() + 2); std::sort(input_dnums.begin(), input_dnums.end()); @@ -1447,8 +1491,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int i = 0; i < num_spatial_dims; ++i) { input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i)); } - const int64 input_features = lhs.dimensions(dnums.feature_dimension()); - const int64 input_batch = lhs.dimensions(dnums.batch_dimension()); + const int64 input_features = lhs.dimensions(dnums.input_feature_dimension()); + const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension()); std::vector kernel_spatial_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1490,8 +1534,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /*allow_negative_padding=*/true)); std::vector dimensions(num_dims); - dimensions[dnums.batch_dimension()] = input_batch; - dimensions[dnums.feature_dimension()] = kernel_output_features; + dimensions[dnums.output_batch_dimension()] = input_batch; + dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i); } @@ -1894,11 +1938,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); + VLOG(3) << "Reshape inferred shape: " + << ShapeUtil::HumanString(inferred_shape); if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "reshape operation has mismatched element counts: from=%lld to=%lld", - ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape)); + "reshape operation has mismatched element counts: from=%lld (%s) " + "to=%lld (%s)", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + ShapeUtil::ElementsIn(inferred_shape), + ShapeUtil::HumanString(inferred_shape).c_str()); } std::vector indices(ShapeUtil::Rank(operand)); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7c9c7e8d6ac32d868ad18d892b3f8e9dc18b8259..d12f7bd1453890db3280e54719a6ce811006336d 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test { // Some handy scalar shapes. const Shape s32_ = ShapeUtil::MakeShape(S32, {}); const Shape f32_ = ShapeUtil::MakeShape(F32, {}); + const Shape f64_ = ShapeUtil::MakeShape(F64, {}); const Shape pred_ = ShapeUtil::MakeShape(PRED, {}); // Some handy vector and matrix shapes of F32 type. @@ -251,6 +252,44 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { .ok()); } +TEST_F(ShapeInferenceTest, Complex) { + auto complex_shape = [&](const Shape& lhs, const Shape& rhs, + const tensorflow::gtl::ArraySlice& bcast) { + return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, + lhs, rhs, bcast); + }; + // Inputs must be FP. + ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); + ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok()); + // Component types must match. + ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok()); + // Only F32->C64 supported. + ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok()); + // Validate correct uses. + Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); + TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {}))); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); + + Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64}); + TF_ASSERT_OK_AND_ASSIGN(result, + complex_shape(vector_64_, matrix_32_64_, {1})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + TF_ASSERT_OK_AND_ASSIGN(result, + complex_shape(matrix_32_64_, vector_64_, {1})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + TF_ASSERT_OK_AND_ASSIGN(result, + complex_shape(matrix_32_64_, matrix_32_64_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); +} + TEST_F(ShapeInferenceTest, VariadicOpTuplify) { StatusOr result = ShapeInference::InferVariadicOpShape( VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); @@ -352,8 +391,10 @@ TEST_F(ShapeInferenceTest, Convolve) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -392,8 +433,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -433,8 +476,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -475,8 +520,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(3); - dnums.set_feature_dimension(2); + dnums.set_input_batch_dimension(3); + dnums.set_output_batch_dimension(3); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); dnums.add_spatial_dimensions(0); dnums.add_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 816c8a7485bb9c5c12d3dc9e17404c74460113f5..8c2640adf52f10c387e7a9c09c0d73a09c054919 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( return {}; } - // We only support folding the RHS. - const int64 kRhsOperandIndex = 1; - auto& operand = *convolution.operand(kRhsOperandIndex); - if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) { - return transposable_conv_operands(convolution, {kRhsOperandIndex}); + const ConvolutionDimensionNumbers& dnums = + convolution.convolution_dimension_numbers(); + + TransposeFolding::OperandIndices operand_set; + for (int64 i = 0; i < convolution.operand_count(); ++i) { + auto& operand = *convolution.operand(i); + if (operand.opcode() == HloOpcode::kTranspose && + operand.user_count() == 1) { + const auto& transpose_dimensions = operand.dimensions(); + // We can transpose the LHS so long as it doesn't move around spatial + // dimensions because ConvolutionDimensionNumbers doesn't have different + // fields for input and output spatial dimensions. + if (i == 0 && + std::any_of(dnums.spatial_dimensions().begin(), + dnums.spatial_dimensions().end(), + [&](const int64 spatial_dimension) { + return transpose_dimensions[spatial_dimension] != + spatial_dimension; + })) { + continue; + } + operand_set.push_back(i); + } } - return {}; + return transposable_conv_operands(convolution, operand_set); } using InstructionOperandsPair = @@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) { // Returns whether the module is changed. bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; - - // We only support fusing the RHS transpose into convolution. - // - // ConvolutionDimensionNumbers doesn't make enough of a distinction between - // the output and the activations. - // - // TODO(b/37125184): Support transposing the LHS too. - if (pair.second.size() != 1 || pair.second.front() != 1) { - return false; - } + auto& operand_indices = pair.second; const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); - HloInstruction& transpose = *convolution.mutable_operand(1); - CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose); - const auto& transpose_dimensions = transpose.dimensions(); - HloInstruction& transpose_operand = *transpose.mutable_operand(0); - - // Everything remains the same except for the kernel dimension numbers. We - // need to apply the transpose permutation to the original shape to figure out - // what the new logical dimensions are. ConvolutionDimensionNumbers new_dnums = dnums; - new_dnums.set_kernel_input_feature_dimension( - transpose_dimensions[dnums.kernel_input_feature_dimension()]); - new_dnums.set_kernel_output_feature_dimension( - transpose_dimensions[dnums.kernel_output_feature_dimension()]); - for (auto& kernel_spatial_dimension : - *new_dnums.mutable_kernel_spatial_dimensions()) { - kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + + HloInstruction* new_lhs; + const int64 kLhsIdx = 0; + if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != + operand_indices.end()) { + HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the input/output dimension + // numbers. We need to apply the transpose permutation to the original shape + // to figure out what the new logical dimensions are. + new_dnums.set_input_batch_dimension( + transpose_dimensions[dnums.input_batch_dimension()]); + new_dnums.set_input_feature_dimension( + transpose_dimensions[dnums.input_feature_dimension()]); + for (const auto& spatial_dimension : dnums.spatial_dimensions()) { + CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + } + new_lhs = &transpose_operand; + } else { + new_lhs = convolution.mutable_operand(kLhsIdx); + } + + HloInstruction* new_rhs; + const int64 kRhsIdx = 1; + if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != + operand_indices.end()) { + HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the kernel dimension numbers. We + // need to apply the transpose permutation to the original shape to figure + // out what the new logical dimensions are. + new_dnums.set_kernel_input_feature_dimension( + transpose_dimensions[dnums.kernel_input_feature_dimension()]); + new_dnums.set_kernel_output_feature_dimension( + transpose_dimensions[dnums.kernel_output_feature_dimension()]); + for (auto& kernel_spatial_dimension : + *new_dnums.mutable_kernel_spatial_dimensions()) { + kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + } + new_rhs = &transpose_operand; + } else { + new_rhs = convolution.mutable_operand(kRhsIdx); } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), convolution.mutable_operand(0), &transpose_operand, - convolution.window(), new_dnums); + convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index a6161b46460068b83fa3f0762e49a10a83b1471c..00462f9be1e9beb2f2694060ebfaa70b0b9dd4a0 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -313,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1)); } -// Test that a transpose of the activations does not get folded into -// convolution. +// Test that a transpose of the activations gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { auto builder = HloComputation::Builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -348,18 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { module.AddEntryComputation(builder.Build(conv)); FoldTranspose(&module); - // Instructions after folding: transpose_x, y, and the convolution. + // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( entry_computation->instructions().begin(), entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(transpose_x)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(conv)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(0, instruction_set.size()) - << "entry_computation should contain exactly 4 instructions."; + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ(dnums.spatial_dimensions(0), + new_conv->convolution_dimension_numbers().spatial_dimensions(0)); + EXPECT_EQ(dnums.spatial_dimensions(1), + new_conv->convolution_dimension_numbers().spatial_dimensions(1)); } } // namespace diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index be457329521c62dc86d60b09cf189c43e6f1dde1..30dabb56bdbcab83aad226bf09e6cfcf015d215d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -44,7 +44,7 @@ namespace xla { // A class describing the source(s) of the Buffer(s) contained in the output of // a particular HLO instruction. The structure of PointsToSet mirrors the -// structure of the instruction's shape which may be an arbitrary tree (eg, a +// structure of the instruction's shape, which may be an arbitrary tree (eg, a // nested tuple). Each node in this tree corresponds to a single buffer in the // instruction's output and contains the set of Buffers which might define // the corresponding buffer. @@ -148,7 +148,7 @@ class PointsToSet { ShapeTree tree_; // PointsToSet contains references (const LogicalBuffer*) to elements within - // TuplePointsToAnalysis so disable copying. + // TuplePointsToAnalysis, so disable copying. TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet); }; diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 3f62501bb5c46f397b5cd96688e50a43f8e83428..adf7972e0d025e178981e16719ffa7edfab0269e 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -54,14 +55,18 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kExp; case UNOP_FLOOR: return HloOpcode::kFloor; + case UNOP_IMAG: + return HloOpcode::kImag; case UNOP_IS_FINITE: return HloOpcode::kIsFinite; case UNOP_LOG: return HloOpcode::kLog; - case UNOP_LOGICAL_NOT: - return HloOpcode::kLogicalNot; + case UNOP_NOT: + return HloOpcode::kNot; case UNOP_NEGATE: return HloOpcode::kNegate; + case UNOP_REAL: + return HloOpcode::kReal; case UNOP_ROUND_NEAREST_AFZ: return HloOpcode::kRoundNearestAfz; case UNOP_SIGN: @@ -79,6 +84,10 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { switch (binop) { + case BINOP_ATAN2: + return HloOpcode::kAtan2; + case BINOP_COMPLEX: + return HloOpcode::kComplex; case BINOP_DOT: return HloOpcode::kDot; case BINOP_MUL: @@ -111,10 +120,16 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kPower; case BINOP_REM: return HloOpcode::kRemainder; - case BINOP_LOGICAL_OR: - return HloOpcode::kLogicalOr; - case BINOP_LOGICAL_AND: - return HloOpcode::kLogicalAnd; + case BINOP_OR: + return HloOpcode::kOr; + case BINOP_AND: + return HloOpcode::kAnd; + case BINOP_SHIFT_LEFT: + return HloOpcode::kShiftLeft; + case BINOP_SHIFT_RIGHT_ARITHMETIC: + return HloOpcode::kShiftRightArithmetic; + case BINOP_SHIFT_RIGHT_LOGICAL: + return HloOpcode::kShiftRightLogical; default: LOG(FATAL) << "unhandled operation " << binop; } @@ -126,8 +141,6 @@ HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { return HloOpcode::kClamp; case TRIOP_SELECT: return HloOpcode::kSelect; - case TRIOP_UPDATE: - return HloOpcode::kUpdate; default: LOG(FATAL) << "unhandled operation " << triop; } @@ -1837,10 +1850,17 @@ UserComputation::GetEmbeddedComputations( XLA_VLOG_LINES(3, session_computation_.DebugString()); std::vector computations; + std::vector sorted_handles; for (const auto& handle_request : session_computation_.requests()) { - int64 handle_value = handle_request.first; + sorted_handles.push_back(handle_request.first); + } + std::sort(sorted_handles.begin(), sorted_handles.end()); + for (int64 handle : sorted_handles) { + const auto& handle_request = session_computation_.requests().find(handle); + CHECK(handle_request != session_computation_.requests().end()); + int64 handle_value = handle_request->first; if (handle_value <= version) { - const OperationRequest& request = handle_request.second; + const OperationRequest& request = handle_request->second; switch (request.request().op_case()) { case OpRequest::kCallRequest: { CHECK_EQ(1, request.embedded_computation_versions_size()); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 8e16056b239a9e1d1776bfe91f6e36862e0feeec..b5eb81dfc6a4117909dcb18fdbe61443b1a1eb95 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return true; } +// Constructs and returns the new shape with the given minor_to_major order in +// its Layout. +StatusOr MakeShapeWithLayoutInternal( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice minor_to_major) { + if (dimensions.size() != minor_to_major.size()) { + return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", + dimensions.size(), minor_to_major.size()); + } + if (element_type == OPAQUE || element_type == TUPLE) { + return InvalidArgument("Unsupported element type: %s", + PrimitiveType_Name(element_type).c_str()); + } + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); + min2maj->Clear(); + for (int64 value : minor_to_major) { + min2maj->Add(value); + } + if (!shape.has_layout()) { + return InvalidArgument("Shape has no layout."); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + return shape; +} + } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { @@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major) { - CHECK_EQ(dimensions.size(), minor_to_major.size()); - Shape shape = MakeShape(element_type, dimensions); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); - for (int64 value : minor_to_major) { - min2maj->Add(value); - } - DCHECK(shape.has_layout()); - TF_DCHECK_OK(ValidateShape(shape)); - return shape; + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( @@ -254,6 +272,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { case U16: case U32: case U64: + case C64: case TUPLE: case OPAQUE: return false; @@ -263,6 +282,10 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } } +/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) { + return primitive_util::IsComplexType(shape.element_type()); +} + /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) { return primitive_util::IsFloatingPointType(shape.element_type()); } @@ -499,11 +522,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); - TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = - ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( + primitive_type, dimensions, min2maj)); } - TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); } @@ -575,6 +597,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(float); case F64: return sizeof(double); + case C64: + return sizeof(complex64); default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index c5800acaf11a99be4545e2ad4330101e7971bd7c..8f8d4a73c9ecb3f4236f3877323ad1127bb0b9c2 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -291,6 +291,9 @@ class ShapeUtil { // Returns whether the element type of the shape is floating point. static bool ElementIsFloating(const Shape& shape); + // Returns whether the element type of the shape is complex. + static bool ElementIsComplex(const Shape& shape); + // Returns whether the element type has the given bit width. static bool ElementHasBitWidth(const Shape& shape, int bits); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 79945b9c77299b7006d014aed4507566e3c2c750..0ba542ad1bec290c35c52a8dd5177893770310fd 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -218,6 +218,10 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20}))); + + EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); + EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); + EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e45b839afd2a9666215744f904dfbed5eca0a41b..4e1be24b61cc436b0baf62cc6e28ad8d13fe71ac 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -23,7 +23,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") @@ -102,28 +101,34 @@ cc_library( deps = [ ":literal_test_util", "//tensorflow/compiler/xla:shape_layout", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_layout", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", - "//third_party/eigen3", + ], +) + +cc_library( + name = "hlo_verified_test_base", + testonly = True, + srcs = ["hlo_verified_test_base.cc"], + hdrs = ["hlo_verified_test_base.h"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -373,6 +378,7 @@ xla_test( name = "params_test", srcs = ["params_test.cc"], shard_count = 30, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -914,6 +920,7 @@ xla_test( name = "reduce_window_test", timeout = "long", srcs = [], + tags = ["optonly"], xla_test_library_deps = [":reduce_window_test_library"], deps = [], ) @@ -981,13 +988,13 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], - linkopts = export_dynamic_linkopts, deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1394,8 +1401,10 @@ xla_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "//third_party/eigen3", ], ) @@ -1461,6 +1470,7 @@ xla_test( xla_test( name = "local_client_execute_test", srcs = ["local_client_execute_test.cc"], + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 24bccf686349e60f2a4fc4b3f2f6ef836f07d107..a62b13e04ff35b06846039d7665dfc8e4205eec2 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -496,58 +496,315 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalAnd) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.LogicalAnd(a, b); + auto out = builder.And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, false}, {true, true}}); + auto b = builder.ConstantR2({{false, true}, {false, true}}); + auto out = builder.And(a, b); + + Array2D expected_array({{false, false}, {false, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.LogicalAnd(a, b); + auto out = builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalOr) { +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, -1, -8}); + auto b = builder.ConstantR1({5, -7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {0, -7, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); + auto b = builder.ConstantR2({{1, -6}, {4, 5}}); + auto out = builder.And(a, b); + + Array2D expected_array({{0, -6}, {4, 5}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 1, 8}); + auto b = builder.ConstantR1({5, 7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {0, 1, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 1}, {3, 8}}); + auto b = builder.ConstantR2({{1, 0}, {7, 6}}); + auto out = builder.And(a, b); + + Array2D expected_array({{0, 0}, {3, 0}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.LogicalOr(a, b); + auto out = builder.Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, false}, {true, true}}); + auto b = builder.ConstantR2({{false, true}, {false, true}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{false, true}, {true, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.LogicalOr(a, b); + auto out = builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalNot) { +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, -1, 8}); + auto b = builder.ConstantR1({5, -7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {5, -1, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, -1}, {8, 8}}); + auto b = builder.ConstantR2({{5, -7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{5, -1}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 1, 8}); + auto b = builder.ConstantR1({5, 7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {5, 7, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 1}, {8, 8}}); + auto b = builder.ConstantR2({{5, 7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{5, 7}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, true, true, false}); - auto out = builder.LogicalNot(a); + auto out = builder.Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, true}, {true, false}}); + auto out = builder.Not(a); + + Array2D expected_array({{true, false}, {false, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); - auto out = builder.LogicalNot(a); + auto out = builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({-1, 0, 1}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {0, -1, -2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); + auto out = builder.Not(a); + + Array2D expected_array({{0, -1}, {-2, -9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 4294967295}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {4294967295, 0}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); + auto out = builder.Not(a); + + Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x12345678), + static_cast(0xF0001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1( + &builder, + {static_cast(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1(&builder, + {static_cast(0xF9234567), + static_cast(0x00100010), 0, 0, 19}, + {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x12345678, 0xF0001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1( + &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 505fa059f28599c7c934749e06be8a7185c66cae..03f5e08315bfed2bcb43ebb7098aaa0b97228605 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -159,7 +159,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { } // Tests implicit broadcasting of PREDs. -XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { +XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { ComputationBuilder b(client_, TestName()); Array2D x_vals(2, 1); @@ -174,7 +174,7 @@ XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { ComputationDataHandle x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); - b.LogicalAnd(x, y, /*broadcast_dimensions=*/{1, 2}); + b.And(x, y, /*broadcast_dimensions=*/{1, 2}); Array3D expected(2, 2, 1); expected(0, 0, 0) = false; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a60d3e50bd4dc78ed8715f8d7814668b95f3d38a..065bce7e3146c93568bbce2b0e7e23ddddc4ea31 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -254,7 +254,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - if (ShapeUtil::ElementIsFloating(expected.shape())) { + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; } else { TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || @@ -282,7 +283,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { - TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape())); + TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); auto expect_near = [&](const Literal& actual, const string& error_message) { LiteralTestUtil::ExpectNear(expected, actual, error, error_message); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 7fe1445b94097f762b777fc6936a0a1ab5a726c8..7cfc276ec19e3b177f87a08e716cb34b7676dd6b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -361,8 +361,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( ComputationBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -384,8 +385,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( ComputationBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -407,8 +409,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( ComputationBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 83882ca75e93ee9edec8e292991b53f1af57bb62..b0a63bccbb93f226175beff2e30e2a243fdca1d3 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -39,7 +39,8 @@ class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 2, 0, 2, 2, 3, 0, 1, 2, + 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -48,7 +49,8 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 0, 1, 2, 3, 2, 3, 2, + 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); @@ -73,14 +75,18 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, ConvolutionDimensionNumbers dim_nums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); // Swap batch_dimension and feature_dimension. - int64 tmp = dim_nums.batch_dimension(); - dim_nums.set_batch_dimension(dim_nums.feature_dimension()); - dim_nums.set_feature_dimension(tmp); + int64 old_input_batch_dim = dim_nums.input_batch_dimension(); + int64 old_output_batch_dim = dim_nums.output_batch_dimension(); + dim_nums.set_input_batch_dimension(dim_nums.input_feature_dimension()); + dim_nums.set_output_batch_dimension(dim_nums.output_feature_dimension()); + dim_nums.set_input_feature_dimension(old_input_batch_dim); + dim_nums.set_output_feature_dimension(old_output_batch_dim); // Swap kernel_input_feature_dimension and kernel_output_feature_dimension. - tmp = dim_nums.kernel_input_feature_dimension(); + int64 old_kernel_input_feature_dim = + dim_nums.kernel_input_feature_dimension(); dim_nums.set_kernel_input_feature_dimension( dim_nums.kernel_output_feature_dimension()); - dim_nums.set_kernel_output_feature_dimension(tmp); + dim_nums.set_kernel_output_feature_dimension(old_kernel_input_feature_dim); builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, dim_nums); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 7d06cce0c8f82e4a1c4fb847638613594257b80f..0cc2e5fb7e655884f3334426a684dd3ce00d4052 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -418,11 +418,13 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { // Tensorflow dimension numbers for 3D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); - dnums.set_feature_dimension(4); + dnums.set_input_feature_dimension(4); + dnums.set_output_feature_dimension(4); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.add_kernel_spatial_dimensions(2); @@ -469,10 +471,12 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { // Tensorflow dimension numbers for 2D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(2); @@ -504,25 +508,41 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) { +struct Convolve1DTestParam { + int64 input_feature; + int64 output_feature; + int64 batch; + int64 window_size; + int64 num_windows; +}; + +class Convolve1D1WindowTest + : public ConvolutionTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) { ComputationBuilder builder(client_, TestName()); - int64 output_feature = 1; - int64 input_feature = 64; - int64 batch = 1; - int64 length = 1; - std::vector input_dims = {batch, 4 + length - 1, input_feature}; - std::vector filter_dims = {4, input_feature, output_feature}; + int64 input_feature = GetParam().input_feature; + int64 output_feature = GetParam().output_feature; + int64 batch = GetParam().batch; + int64 num_windows = GetParam().num_windows; + int64 window_size = GetParam().window_size; + std::vector input_dims = {batch, window_size + num_windows - 1, + input_feature}; + std::vector filter_dims = {window_size, input_feature, output_feature}; Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - // Tensorflow dimension numbers for 2D convolution. + // Tensorflow dimension numbers for 1D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); - dnums.set_feature_dimension(2); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); dnums.add_kernel_spatial_dimensions(0); dnums.set_kernel_input_feature_dimension(1); dnums.set_kernel_output_feature_dimension(2); @@ -532,28 +552,57 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) { } std::vector input_elems(ShapeUtil::ElementsIn(input_shape), 1.0); - // std::iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = Literal::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0); - // std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = Literal::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - std::vector expect_elems(batch * output_feature * length, 256); + std::vector expect_elems(batch * output_feature * num_windows, + window_size * input_feature); auto expected_r1 = Literal::CreateR1(expect_elems); - auto expected_r4 = - expected_r1->Reshape({batch, length, output_feature}).ConsumeValueOrDie(); + auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); - auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, *expected_r3, {input_literal.get(), filter_literal.get()}, error_spec_); } +INSTANTIATE_TEST_CASE_P( + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest, + ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, + Convolve1DTestParam{160, 1, 1, 5, 1}, + Convolve1DTestParam{24, 1, 1, 20, 1}, + Convolve1DTestParam{30, 1, 1, 20, 1}, + Convolve1DTestParam{23, 1, 1, 20, 20}, + Convolve1DTestParam{25, 1, 1, 20, 1}, + Convolve1DTestParam{24, 1, 1, 10, 5}, + Convolve1DTestParam{160, 1, 1, 10, 1}, + Convolve1DTestParam{255, 1, 1, 3, 1}, + Convolve1DTestParam{130, 1, 1, 1, 3}, + Convolve1DTestParam{64, 1, 1, 1, 1}, + Convolve1DTestParam{128, 1, 1, 1, 1}, + Convolve1DTestParam{139, 1, 1, 128, 1}, + Convolve1DTestParam{1, 10, 10, 1, 10}, + Convolve1DTestParam{1, 10, 130, 1, 2}, + Convolve1DTestParam{1, 10, 130, 1, 1}, + Convolve1DTestParam{1, 64, 64, 1, 10}, + Convolve1DTestParam{1, 65, 65, 1, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{128, 128, 128, 128, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{2, 2, 2, 2, 1}, + Convolve1DTestParam{161, 1, 1, 10, 1}, + Convolve1DTestParam{900, 1, 1, 10, 1}, + Convolve1DTestParam{640, 3, 3, 128, 1}) + +); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 145918db3e5e57c39054706d53bbfb7648af3143..9b36e3722b8f8a5d01c426425fdfb0c4b9ae3a16 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -974,10 +974,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1014,10 +1016,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1054,10 +1058,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1091,10 +1097,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 342478bc744273be9deb8b750b5a6a47b7d9f91b..74f73a1ddc15be033e52b0b45f9961e5dc3a1ecb 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -31,19 +32,19 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" - -extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { +namespace { +void R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { +void R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { +void Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; out[0] = array[0] + 1; @@ -51,6 +52,11 @@ extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } +} // namespace + +REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); +REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); +REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); namespace xla { namespace { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 224aa57899d04eb8309b2337bb8fc936a81d350f..cf089d748dcd4f5db637ff9087c5fbc504c82572 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -347,7 +347,7 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); } -TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { constexpr bool kLhsRowMajor = true; constexpr bool kRhsRowMajor = true; TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); @@ -357,7 +357,11 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { TestNonsquareMatrixDot(); } -TEST_F(DotOperationTest, ConcurrentMatMul) { +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) { + TestNonsquareMatrixDot(); +} + +XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { ComputationBuilder builder(client_, TestName()); auto matrix1 = builder.ConstantR2({{1.0, 2.0}, {3.0, 4.0}}); auto matrix2 = builder.ConstantR2({{5.0, 6.0}, {7.0, 8.0}}); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 3bf9ccb19745b9e91d99614792dbec0443818f2b..a8f6488996087b57e3121ce2c7de918070950c72 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -17,8 +17,12 @@ limitations under the License. #include #include #include +#include #include +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" @@ -37,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -250,6 +255,42 @@ XLA_TEST_F(FusionTest, Parameter) { ErrorSpec(1e-4)); } +XLA_TEST_F(FusionTest, RandomizedParallelPartition) { + // Tests parallel partitioning of a fusion instruction. + // Create shape with random outer dimension size to generate random parallel + // partition counts for each test run. + const int seed = tensorflow::testing::RandomSeed(); + LOG(INFO) << "RandomizedParallelPartition seed: " << seed; + std::mt19937 generator(seed); + std::uniform_int_distribution distribution(128, 1024); + const int64 rand_dim0_size = distribution(generator); + const int64 dim1_size = 1024; + Shape shape = + ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); + // Build simple fusion computation: y = x^2 (elementwise). + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto x = + builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); + auto y = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x)); + + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two}, + HloInstruction::FusionKind::kLoop); + // Compute result. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + // Every element of result should be y = x^2 = 4.0. + for (int i = 0; i < rand_dim0_size; ++i) { + for (int j = 0; j < dim1_size; ++j) { + EXPECT_EQ(4.0, result->Get({i, j})); + } + } +} + XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); @@ -722,47 +763,104 @@ void BM_ParallelFusion(int num_iters) { auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); StreamExecutorMemoryAllocator allocator(platform, executors); - const int64 intra_op_parallelism_threads = 16; + const int64 intra_op_parallelism_threads = 24; xla::LocalClientOptions client_options; client_options.set_platform(platform); client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); auto client = ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); - const int64 dim_size = 1024; - // Create a simple fusable elementwise computation. + auto* transfer_manager = + TransferManager::GetForPlatform(platform).ValueOrDie(); + int device_ordinal = client->default_device_ordinal(); + + // Computation shape parameters. + const int64 param0_dim0 = 1024; + const int64 param0_dim1 = 1024; + const int64 param1_dim0 = 1024; + const int64 param1_dim1 = 1024; + const int64 param2_dim0 = 1024; + const int64 param2_dim1 = 1024; + + // Create computation. ComputationBuilder builder(client, "ParallelFusion"); - Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); - auto input0 = builder.Broadcast(builder.ConstantR0(1.5f), - AsInt64Slice(input_shape.dimensions())); - auto input1 = builder.Broadcast(builder.ConstantR0(2.0f), - AsInt64Slice(input_shape.dimensions())); - auto input2 = builder.Broadcast(builder.ConstantR0(3.0f), - AsInt64Slice(input_shape.dimensions())); - auto x = builder.Mul(input0, input1); - auto y = builder.Add(x, input2); + Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); + auto param0 = builder.Parameter(0, shape0, "param0"); + Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); + auto param1 = builder.Parameter(1, shape1, "param1"); + Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); + auto param2 = builder.Parameter(2, shape2, "param2"); + + auto x = builder.Mul(param0, param1); + auto y = builder.Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); + // Transfer literals to device. + auto buffer0 = + ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param0_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({}))); + + auto buffer1 = + ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param1_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({}))); + + auto buffer2 = + ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param2_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({}))); + + // Build executable. std::unique_ptr executable = - client->Compile(computation, {}, ExecutableBuildOptions()) + client + ->Compile(computation, + {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); - // Run some warm-up executions. + se::Stream stream(executors[client->default_device_ordinal()]); + stream.Init(); + + // Initialize thread pool. + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", + intra_op_parallelism_threads); + tensorflow::EigenThreadPoolWrapper tp(&pool); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + // Initialize ExecutableRunOptions. ExecutableRunOptions options; - options.set_allocator(&allocator); + options.set_allocator(&allocator).set_stream(&stream); + options.set_intra_op_thread_pool(&device); + + // Run some warm-up executions. const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. - tensorflow::testing::BytesProcessed(static_cast(num_iters) * dim_size * - dim_size * sizeof(float)); + const int64 total_bytes = param0_dim0 * param0_dim0 + + param1_dim0 * param1_dim0 + + param2_dim0 * param2_dim0; + tensorflow::testing::BytesProcessed(static_cast(num_iters) * + total_bytes * sizeof(float)); tensorflow::testing::UseRealTime(); tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 26513d6ce8e0b8896e9f9838ecf28f1ed5bbb383..d73c05ff92578209143e0679558848160cae99bd 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -19,24 +19,9 @@ limitations under the License. #include #include -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/shape_layout.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,22 +30,6 @@ namespace se = ::perftools::gputools; namespace xla { -// Define this in .cc file to avoid having to include eigen or forward declare -// these types in the header. -struct HloTestBase::EigenThreadPoolWrapper { - std::unique_ptr pool; - std::unique_ptr device; -}; - -HloTestBase::HloTestBase() {} - -HloTestBase::~HloTestBase() { - // Deallocate all the memory allocated during the tests. - for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} - /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; @@ -80,98 +49,25 @@ StatusOr HloTestBase::Execute( tensorflow::gtl::ArraySlice arguments, Shape* result_shape) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend().compiler()->Compile(std::move(module), - backend().default_stream_executor())); - - se::Stream stream(backend().default_stream_executor()); - stream.Init(); - - ExecutableRunOptions run_options; - run_options.set_stream(&stream); - run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - backend().eigen_intra_op_thread_pool_device()); - - HloExecutionProfile hlo_execution_profile; - ServiceExecutableRunOptions service_run_options( - run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - &hlo_execution_profile)); - TF_RET_CHECK(stream.BlockHostUntilDone()); - - allocations_.push_back(result); - - *result_shape = executable->result_shape(); - - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } - } - - return result; + return runner_.Execute(std::move(module), arguments, result_shape); } se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - // Allocate memory on the device using the stream executor. - int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; + return runner_.TransferToDevice(literal).ValueOrDie(); } std::unique_ptr HloTestBase::TransferFromDevice( const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return literal; + return runner_.TransferFromDevice(shape, device_base).ValueOrDie(); } std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { - Shape result_shape; - se::DeviceMemoryBase device_base = - Execute(std::move(module), arguments, &result_shape).ValueOrDie(); - return TransferFromDevice(result_shape, device_base); + return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie(); } -Backend& HloTestBase::backend() { - if (!backend_) { - backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); - VLOG(1) << "executing on platform " << backend().platform()->Name(); - } - return *backend_; -} +Backend& HloTestBase::backend() { return runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 275f1f5c7baa11245186d119f5b38b4d02b84566..7f068dce36be3546298de2f06bf6d33446d07ca2 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -39,10 +39,9 @@ namespace xla { // building a graph of HLO instructions to run. class HloTestBase : public ::testing::Test { protected: - struct EigenThreadPoolWrapper; - HloTestBase(); + HloTestBase() {} - ~HloTestBase() override; + ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug @@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test { static string TestName(); - // Creates (if necessary) and returns the default backend. If creation fails, - // crashes the program. - // - // This creates the backend lazily so it's possible to instantiate an - // HloTestBase in a program without any backends linked in. + // Returns the backend owned by the HloRunner. Backend& backend(); - // This vector contains handles of all the device memory allocations performed - // by the test. These are deallocated on destruction of the test object. - std::vector allocations_; + HloRunner runner_; ErrorSpec error_spec_{0.0001}; - - std::unique_ptr thread_pool_wrapper_; - - private: - std::unique_ptr backend_; // Lazily populated. Access via backend(). }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc new file mode 100644 index 0000000000000000000000000000000000000000..31060b9e80fcd50aefdedca27c70ec8a9b8be743 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +/*static*/ int64 HloVerifiedTestBase::DefaultShapeSize(const Shape& shape) { + constexpr int64 kPointerSize = sizeof(void*); + if (ShapeUtil::IsOpaque(shape)) { + return kPointerSize; + } + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +HloVerifiedTestBase::HloVerifiedTestBase() : shape_size_fn_(DefaultShapeSize) {} + +HloVerifiedTestBase::~HloVerifiedTestBase() { + // We can't call the ASSERT or EXPECT test macros in destructors, so we + // perform HLO verification in TearDown, and use the CHECK here to ensure + // users don't accidentally override the verification. + CHECK(tear_down_called_) + << "TearDown was never called; subclasses of HloVerifiedTestBase that " + << "override TearDown must call the superclass TearDown."; +} + +void HloVerifiedTestBase::TearDown() { + EXPECT_FALSE(tear_down_called_) + << "TearDown called more than once; it should be called exactly once."; + tear_down_called_ = true; + if (module_) { + HloVerifier verifier(shape_size_fn_); + xla::StatusOr mutated = verifier.Run(module_.get()); + if (!mutated.ok()) { + ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); + } else { + EXPECT_FALSE(mutated.ValueOrDie()) + << "HloVerifier should never mutate the HloModule"; + } + } + HloTestBase::TearDown(); +} + +HloModule& HloVerifiedTestBase::module() { + if (!module_) { + module_ = CreateNewModule(); + } + return *module_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h new file mode 100644 index 0000000000000000000000000000000000000000..b3d6b5af3b46f932707abf309669d23c327d1334 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +// A base class for HLO tests that stores a default HloModule, and automatically +// performs verification on that module on tear-down. +class HloVerifiedTestBase : public HloTestBase { + public: + // Returns the size in bytes of the given shape, using a default pointer size. + static int64 DefaultShapeSize(const Shape& shape); + + protected: + HloVerifiedTestBase(); + ~HloVerifiedTestBase() override; + + // Performs verification on the default HloModule returned by module(). + // Automatically called by the testing framework for each test. + // + // REQUIRED: subclasses that override TearDown() must call this explicitly. + void TearDown() override; + + // Returns the default HloModule, lazily creating it if necessary via + // HloTestBase::CreateNewModule(). + HloModule& module(); + + // Sets the shape-size function used during hlo verification. If this isn't + // called, DefaultShapeSize is used instead. + void SetShapeSizeFn(std::function shape_size_fn) { + shape_size_fn_ = std::move(shape_size_fn); + } + + private: + std::unique_ptr module_; // Lazily populated. Access via module(). + std::function shape_size_fn_; + bool tear_down_called_ = false; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 4d8b50fbbf715e8d491667ecb4f4f336ef2d8a68..95a52ecd2f5cfc97ec1ccba7d1b7ca6257a8267e 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -39,28 +39,60 @@ limitations under the License. namespace xla { -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual)); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return ::testing::AssertionFailure() + << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) + << " got: " << ShapeUtil::HumanString(actual); + } if (ShapeUtil::IsTuple(expected)) { - ASSERT_EQ(ShapeUtil::TupleElementCount(expected), - ShapeUtil::TupleElementCount(actual)); + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return ::testing::AssertionFailure() + << "want tuple element count: " + << ShapeUtil::TupleElementCount(expected) + << " got tuple element count: " + << ShapeUtil::TupleElementCount(actual); + } for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + ::testing::AssertionResult result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result) { + return result; + } } } else { - ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); - ASSERT_EQ(expected.element_type(), actual.element_type()) - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return ::testing::AssertionFailure() + << "want rank of: " << ShapeUtil::HumanString(expected) + << " got rank of: " << ShapeUtil::HumanString(actual); + } + if (expected.element_type() != actual.element_type()) { + return ::testing::AssertionFailure() + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return ::testing::AssertionFailure() + << "want dimensions_size " << expected.dimensions_size() + << " got dimensions_size " << actual.dimensions_size(); + } for (int i = 0; i < expected.dimensions_size(); ++i) { - ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); + if (expected.dimensions(i) != actual.dimensions(i)) { + return ::testing::AssertionFailure() + << "mismatch in dimension #" << i + << " expected: " << ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); + } } } + return ::testing::AssertionSuccess(); +} + +/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, + const Shape& actual) { + ASSERT_TRUE(EqualShapes(expected, actual)); } /* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( @@ -124,6 +156,15 @@ template <> ::testing::AssertionResult CompareEqual(double lhs, double rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } +template <> +::testing::AssertionResult CompareEqual(complex64 lhs, + complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} // A recursive function which iterates through every index of expected and // actual literal and compares their values elementwise. Returns true if all @@ -203,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case F64: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case C64: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case TUPLE: { bool tuple_match = true; for (int i = 0; i < actual.tuple_literals_size(); ++i) { @@ -263,7 +307,14 @@ class NearComparator { VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + ::testing::AssertionResult equal_shapes = + LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); + if (!equal_shapes) { + EXPECT_TRUE(equal_shapes); + return false; + } // Set up members used during the comparison. num_miscompares_ = 0; @@ -286,6 +337,9 @@ class NearComparator { case F64: ExpectLiteralsNear(expected, actual, 0); break; + case C64: + ExpectLiteralsNear(expected, actual, 0); + break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " << PrimitiveType_Name(expected.shape().element_type()) @@ -326,6 +380,19 @@ class NearComparator { } private: + template + bool NanMismatch(NativeT lhs, NativeT rhs) { + return std::isnan(lhs) != std::isnan(rhs); + } + + template + void ExpectNear(NativeT expected, NativeT actual, + const ::testing::Message& message) { + EXPECT_NEAR(expected, actual, error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; + } + // EXPECTs that the two given scalar values are within the error bound. Keeps // track of how many mismatches have occurred to keep the size of the output // manageable. @@ -351,7 +418,7 @@ class NearComparator { "index %s abs_diff %f rel_err %f", LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff, rel_err); - bool nan_mismatch = std::isnan(actual) != std::isnan(expected); + bool nan_mismatch = NanMismatch(expected, actual); bool mismatch = (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); if (mismatch) { @@ -359,11 +426,12 @@ class NearComparator { abs_expected_miscompare_sum_ += std::abs(expected); const int64 kMaxFailures = 2; if (num_miscompares_ < kMaxFailures) { - EXPECT_NEAR(expected, actual, error_.abs) - << "mismatch at index " + ::testing::Message msg; + msg << "mismatch at index " << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff " << abs_diff << " rel err " << rel_err << " failure #" << num_miscompares_; + ExpectNear(expected, actual, msg); } else if (num_miscompares_ == kMaxFailures) { LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; @@ -431,6 +499,23 @@ class NearComparator { std::vector max_abs_multi_index_; }; +template <> +bool NearComparator::NanMismatch(complex64 lhs, complex64 rhs) { + return std::isnan(lhs.real()) != std::isnan(rhs.real()) || + std::isnan(lhs.imag()) != std::isnan(rhs.imag()); +} + +template <> +void NearComparator::ExpectNear(complex64 expected, complex64 actual, + const ::testing::Message& message) { + EXPECT_NEAR(expected.real(), actual.real(), error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; + EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index f645c4e8dcda73806a4204876716b93aa5fb7185..467d44b857b74d2a38e9b3f8a32a9b1d39a4a10d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -50,6 +50,8 @@ class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. + static ::testing::AssertionResult EqualShapes(const Shape& expected, + const Shape& actual); static void AssertEqualShapes(const Shape& expected, const Shape& actual); // Asserts that the provided shapes are equal as defined in AssertEqualShapes diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index c74213f7f9198741770713aa950e78f2e5ec951d..329b53012f58c8d084cc05f9a567a8aa432c4a3a 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -859,6 +860,31 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { Literal::CreateR0(123456789000LL).get()})); } +// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. +// 2017-10-18. +XLA_TEST_F(LocalClientExecuteTest, + DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(InfeedOutfeedTest))) { + ComputationBuilder builder(local_client_, TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {3}); + auto in = builder.Infeed(shape); + auto constant = builder.ConstantR1({1.0f, 2.0f, 3.0f}); + auto sum = builder.Add(in, constant); + builder.Outfeed(sum, shape, /*outfeed_config=*/""); + + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", + [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); + + ASSERT_IS_OK(local_client_->TransferToInfeed( + *Literal::CreateR1({-5.0, 123.0, 42.0}))); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + local_client_->TransferFromOutfeed(&shape)); + + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); +} + // Benchmark that measures the overhead of the LocalClient API when running a // trivial computation void BM_LocalClientOverhead(int num_iters) { diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 05e282d2081a736fcae1d6a279cdcc37682696f7..c11e1df0a7890a6c3aada5ff47494b42fdaf3b9d 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -90,6 +90,9 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const { /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator( perftools::gputools::Platform* platform) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + tensorflow::mutex_lock lock(mu); + if (allocator_ == nullptr) { allocator_ = new TestAllocator( platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie() diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 17c25adfef9ea2cbe715cd82a199f479e53529b8..3edfcb656ed8278d403103f0cfd820a10892476a 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -128,8 +128,8 @@ class LocalClientTestBase : public ::testing::Test { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - // The allocator must live as long as the service which lives until the end of - // the process, so make the allocator static. + // The allocator must live as long as the service, which lives until the end + // of the process. So make the allocator static. static TestAllocator* allocator_; perftools::gputools::StreamExecutor* stream_executor_; diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 2271f32c5946f3d3e7e6b43b089e68ab3101b61b..7bc3185c367f076c9a7d211c9799557e1a91d92f 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -120,10 +120,10 @@ class ReduceTest : public ClientLibraryTestBase { Computation reduce; if (and_reduce) { init_value = builder.ConstantR0(true); - reduce = CreateScalarLogicalAndComputation(&builder); + reduce = CreateScalarAndComputation(&builder); } else { init_value = builder.ConstantR0(false); - reduce = CreateScalarLogicalOrComputation(&builder); + reduce = CreateScalarOrComputation(&builder); } builder.Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); @@ -457,7 +457,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); + auto log_ = builder.Tanh(input); auto reshape = builder.Reshape(log_, {rows, cols}); builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); @@ -473,7 +473,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { for (int64 colno = 0; colno < cols / 2; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += log(input_data(rowno, major, colno)); + column_sum += tanh(input_data(rowno, major, colno)); } expected.push_back(column_sum); } @@ -502,8 +502,8 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { ComputationBuilder builder(client_, TestName()); auto add = CreateScalarAddComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); - auto broacasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broacasted, builder.ConstantR0(0.0f), add, {0, 1}); + auto broadcasted = builder.Broadcast(scalar, {500, 500}); + builder.Reduce(broadcasted, builder.ConstantR0(0.0f), add, {0, 1}); float expected = 42.0f * static_cast(500 * 500); ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -514,8 +514,8 @@ XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { ComputationBuilder builder(client_, TestName()); auto max = CreateScalarMaxComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); - auto broacasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broacasted, builder.ConstantR0(0.0f), max, {0, 1}); + auto broadcasted = builder.Broadcast(scalar, {500, 500}); + builder.Reduce(broadcasted, builder.ConstantR0(0.0f), max, {0, 1}); float expected = 42.0f; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -729,16 +729,14 @@ XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { std::numeric_limits::max()); } -XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalAnd) { - RunVectorizedReduceTestForType(CreateScalarLogicalAndComputation, - [](bool a, bool b) { return a && b; }, - true); +XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { + RunVectorizedReduceTestForType( + CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true); } -XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalOr) { - RunVectorizedReduceTestForType(CreateScalarLogicalOrComputation, - [](bool a, bool b) { return a || b; }, - false); +XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { + RunVectorizedReduceTestForType( + CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false); } class ReduceR3ToR2Test : public ReduceTest, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 7b7f2687286916aff9c47a7c165619bbe84368e8..6c9b62b48d8bb2ad93b2ce98839e5e52d8eaa8cc 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -76,6 +76,20 @@ class ReduceWindowTest : public ClientLibraryTestBase { ComputationBuilder builder_; }; +TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = builder_.ConstantR1({1, 1, 1, 1}); + const auto init_value = builder_.ConstantR0(0); + TF_ASSERT_OK(builder_.first_error()); + builder_.ReduceWindow(input, init_value, + CreateScalarAddComputation(F32, &builder_), + /*window_dimensions=*/{1, 2}, + /*window_strides=*/{1}, Padding::kValid); + ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) + << builder_.first_error(); + ASSERT_THAT(builder_.first_error().error_message(), + ::testing::HasSubstr("Want input dimensions size")); +} + TEST_F(ReduceWindowTest, Min3In5Stride2) { const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); ReduceWindowMin(input, {3}, {2}, Padding::kValid); diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index bb7160e3a03053a4f3d8da712c1424e50f37dfeb..72c68f24a0a954deb0564e9a0e924edfaf5b5484 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -47,7 +47,7 @@ class ReshapeTest : public ClientLibraryTestBase { }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial1x1) { +XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2({{1.0}}); builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); @@ -55,6 +55,22 @@ XLA_TEST_F(ReshapeTest, Trivial1x1) { ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); } +XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({1.0}); + builder.Collapse(/*operand=*/a, /*dimensions=*/{}); + + ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({1.0}); + builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); + + ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); +} + // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 77d1c019f3a23f79237e624dabf8972a6c1d3c72..b5e7570778ffeca66cc15d7cd2b153639637a647 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -459,39 +459,99 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { ComputeAndCompareR0(&builder, 2, {}); } -XLA_TEST_F(ScalarComputationsTest, LogicalAnd) { +XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalAnd(builder.ConstantR0(x), - builder.ConstantR0(y)); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x && y, {}); } } } -XLA_TEST_F(ScalarComputationsTest, LogicalOr) { +XLA_TEST_F(ScalarComputationsTest, AndS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, AndU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalOr(builder.ConstantR0(x), - builder.ConstantR0(y)); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x || y, {}); } } } -XLA_TEST_F(ScalarComputationsTest, LogicalNot) { +XLA_TEST_F(ScalarComputationsTest, OrS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalNot(builder.ConstantR0(x)); + builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, !x, {}); } } +XLA_TEST_F(ScalarComputationsTest, NotS32) { + for (int32 x : {-1, 0, 1}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0(x)); + + ComputeAndCompareR0(&builder, ~x, {}); + } +} + +XLA_TEST_F(ScalarComputationsTest, NotU32) { + for (uint32 x : {0, 1, 2}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0(x)); + + ComputeAndCompareR0(&builder, ~x, {}); + } +} + XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { ComputationBuilder builder(client_, TestName()); builder.Select(builder.ConstantR0(true), // The predicate. diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index efae13a43a058b03a45174c8260bce2ed70cb75c..fa4192e9281784a4a3063601afe89fba6a9dac18 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -41,7 +41,11 @@ class UnaryOpTest : public ClientLibraryTestBase { auto arg = builder.ConstantR1({}); auto abs = builder.Abs(arg); - ComputeAndCompareR1(&builder, {}, {}); + if (primitive_util::NativeToPrimitiveType() == C64) { + ComputeAndCompareR1(&builder, {}, {}); + } else { + ComputeAndCompareR1(&builder, {}, {}); + } } template @@ -80,14 +84,58 @@ int UnaryOpTest::inf() { return 2147483647; } +template <> +void UnaryOpTest::AbsTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1({{-2, 0}, + {0, 25}, + {0, 0}, + {-0.3f, 0.4f}, + {0, inf()}, + {-inf(), 0}}); + auto abs = builder.Abs(arg); + + std::unique_ptr expected = + Literal::CreateR1({2, 25, 0, 0.5, inf(), inf()}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); +} + +template <> +void UnaryOpTest::SignTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1( + {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); + auto sign = builder.Sign(arg); + + std::unique_ptr expected = Literal::CreateR1( + {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); +} + +template <> +void UnaryOpTest::SignAbsTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = + builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); + auto sign = builder.Sign(arg); + auto abs = builder.Abs(arg); + builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg); + + std::unique_ptr expected = + Literal::CreateR1({0, 0, 0, 0}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); +} + XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { AbsSize0TestHelper(); AbsSize0TestHelper(); + AbsSize0TestHelper(); } XLA_TEST_F(UnaryOpTest, AbsTestR1) { AbsTestHelper(); AbsTestHelper(); + AbsTestHelper(); } XLA_TEST_F(UnaryOpTest, AbsTestR0) { @@ -98,34 +146,44 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) { auto absf = builder.Abs(argf); auto argf0 = builder.ConstantR0(-0.0f); auto absf0 = builder.Abs(argf0); - builder.Add(absf0, builder.Add(absf, builder.ConvertElementType( - absi, PrimitiveType::F32))); + auto argc = builder.ConstantR0({-0.3f, 0.4f}); + auto absc = builder.Abs(argc); + builder.Add(builder.Add(absc, absf0), + builder.Add(absf, builder.ConvertElementType(absi, F32))); - ComputeAndCompareR0(&builder, 8.0f, {}); + ComputeAndCompareR0(&builder, 8.5f, {}); } XLA_TEST_F(UnaryOpTest, SignTestR0) { ComputationBuilder builder(client_, TestName()); auto argi = builder.ConstantR0(-5); - auto absi = builder.Sign(argi); + auto sgni = builder.Sign(argi); // -1 auto argf = builder.ConstantR0(-4.0f); - auto absf = builder.Sign(argf); + auto sgnf = builder.Sign(argf); // -1 auto argf0 = builder.ConstantR0(-0.0f); - auto absf0 = builder.Sign(argf0); - builder.Add(absf0, builder.Add(absf, builder.ConvertElementType( - absi, PrimitiveType::F32))); - - ComputeAndCompareR0(&builder, -2.0f, {}); + auto sgnf0 = builder.Sign(argf0); // 0 + auto argc = builder.ConstantR0({-.3, .4}); + auto sgnc = builder.Sign(argc); // (-.6, .8) + builder.Add(sgnc, builder.ConvertElementType( + builder.Add(builder.Add(sgnf0, sgnf), + builder.ConvertElementType(sgni, F32)), + C64)); + + std::unique_ptr expected = + Literal::CreateR0({-2.6f, 0.8f}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); SignTestHelper(); + SignTestHelper(); } XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper(); SignAbsTestHelper(); + SignAbsTestHelper(); } XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index bb2d90fa94abbf52c340d366ddc55f7bdefb6543..71a1b0abee51ba2819daed23208b0da8d5107207 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -169,7 +169,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) { { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); + auto result = builder.Or(prev, builder.ConstantR0(true)); body = builder.Build().ConsumeValueOrDie(); } @@ -437,7 +437,7 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); - auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); + auto new_pred = builder.Or(pred, builder.ConstantR0(true)); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 0451537af777e127df333da8a941a89e6fe315c2..759921dce5acf3cd23a121776f3ab0731c9bb623 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -210,6 +210,18 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "hlo_proto_to_json", + srcs = ["hlo_proto_to_json.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index aa297ac171d76d73d4c71c7cdfd2c2e2b9fd9a3d..5ede37b8737bd4fa6235464ddeb6382af17c8a80 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -86,9 +86,9 @@ void RealMain(tensorflow::gtl::ArraySlice args) { layouts.push_back(&program_shape->parameters(i)); } StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + local_service->CompileExecutable(computation.handle(), layouts, + &program_shape->result(), + /*device_ordinal=*/0); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 2a3a8803283c62d12d8e2d213aa1730e8bd33244..78d8fb1f4330aed899ca917e66fae819a002b3a9 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -61,9 +61,9 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { layouts.push_back(&program_shape->parameters(i)); } StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + local_service->CompileExecutable(computation.handle(), layouts, + &program_shape->result(), + /*device_ordinal=*/0); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e02e17db65c0a4220672733be8319e1a0cc4f0f --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: +// hlo_proto_to_json --input_file=some_binary_proto +// --output_file=path_to_dump_output +// +// Reads one serilized Hlo module, convert it into JSON format and dump into +// some output directory. some_binaray_proto is obtained by serializing Hlo +// module to disk using --xla_dump_hlo_proto_to debug optoin. + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +using tensorflow::Env; +using xla::string; + +namespace xla { +namespace tools { + +StatusOr ToJson(const tensorflow::protobuf::Message& message) { + string json_output; + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + json_options.always_print_primitive_fields = true; + auto status = tensorflow::protobuf::util::MessageToJsonString( + message, &json_output, json_options); + if (!status.ok()) { + return InternalError("MessageToJsonString failed: %s", + status.error_message().data()); + } + return json_output; +} + +void RealMain(const string& input, const string& output) { + HloProto hlo_proto; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input, + &hlo_proto)) + << "Can't open, read, or parse input file " << input; + + auto statusor = ToJson(hlo_proto); + QCHECK(statusor.ok()) << "Error converting " << input << " to JSON." + << statusor.status(); + + TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output, + statusor.ValueOrDie())); +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + string input_file, output_file; + const std::vector flag_list = { + tensorflow::Flag("input_file", &input_file, "file to convert."), + tensorflow::Flag("output_file", &output_file, "converted file"), + }; + const string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(parse_ok && argc == 1) << "\n" << usage; + + QCHECK(!input_file.empty()) << "--input_file is required"; + QCHECK(!output_file.empty()) << "--output_file is required"; + + xla::tools::RealMain(input_file, output_file); + + return 0; +} diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c84ca9fc833881ce49bcaad5dd85394145151912 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -0,0 +1,84 @@ +# Build file for the Hlo parser. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo_lexer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2feaa49db86ea700cab0b794ec441b95ac03b468 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -0,0 +1,85 @@ +# HloModule string syntax + +TODO: Support all subcomputations (for fusion, reduce, ...). + +TODO: Support all extra attributes, e.g. dimensions, strides. + +```yacc +hlo_module + : 'HloModule' name computations + ; + +computations + : computation + | computation computations + ; + +computation + : 'ENTRY' name param_list '->' shape instruction_list + | name param_list '->' shape instruction_list + ; + +instruction_list + : '{' instruction_list1 '}' + ; +instruction_list1 + : instruction + | instruction_list1 instruction + ; +instruction + : 'ROOT' name '=' shape opcode operands extra_attributes + | name '=' shape opcode operands extra_attributes + ; + +operands + : '(' operands1 ')' + ; +operands1 + : /*empty*/ + | operand + | operands1 ',' operand + ; +operand + : shape name + ; + +extra_attributes + : /*empty*/ + | ',' extra_attribute + | ',' extra_attribute extra_attributes + ; +extra_attribute + : attribute_name attribute_value + ; + +param_list + : '(' param_list1 ')' + ; +param_list1 + : /*empty*/ + | param + | param_list1 ',' param + ; +param + : name shape + ; + +shape + : shape_val_ + | '(' tuple_elements ')' + ; +tuple_elements + : /*empty*/ + | shape (',' shape)* + ; + +name + : identifier ':' + | '%' identifier + ; + +identifier + : [a-zA-Z_][a-zA-Z0-9_.-]* + ; + +``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc new file mode 100644 index 0000000000000000000000000000000000000000..fba343de482ab11dd12fdeb4fa202b50d0bcc2b5 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { +namespace tools { + +using tensorflow::StringPiece; + +namespace { + +constexpr int kEOF = -1; +constexpr int kError = -2; + +// [a-zA-Z0-9_.-] +bool IsIdentifierChar(char c) { + return isalnum(static_cast(c)) || c == '-' || c == '.' || + c == '_'; +} + +} // namespace + +int HloLexer::GetNextChar() { + int current_char = PeekCurrentChar(); + if (current_char != kEOF && current_char != kError) { + current_ptr_++; + } + return current_char; +} + +int HloLexer::PeekCurrentChar() const { + if (current_ptr_ == buf_.end()) { + return kEOF; + } + char current_char = *current_ptr_; + if (current_char == 0) { + // '\0' should not appear in the middle of the string. + return kError; + } + return static_cast(current_char); +} + +bool HloLexer::CanDereference(const char* ptr) const { + return ptr < buf_.end() && ptr >= buf_.begin(); +} + +StringPiece HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return StringPiece(begin, end - begin); +} + +tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( + const char* begin, const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return tensorflow::RegexpStringPiece(begin, end - begin); +} + +TokKind HloLexer::LexToken() { + while (true) { + token_start_ = current_ptr_; + + int current_char = GetNextChar(); + switch (current_char) { + default: + // [a-zA-Z_] + if (isalpha(static_cast(current_char)) || + current_char == '_') { + return LexIdentifier(); + } + return TokKind::kError; + case kEOF: + // Hit the end of the input buffer. + return TokKind::kEof; + case kError: + // Hit an invalid character in the input buffer. + return TokKind::kError; + case ' ': + case '\t': + case '\n': + case '\r': + // Ignore whitespace. + continue; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + if (current_char == '-' && PeekCurrentChar() == '>') { + current_ptr_++; + return TokKind::kArrow; + } + return LexDigitOrNegative(); + case '=': + return TokKind::kEqual; + case ',': + return TokKind::kComma; + case '%': + return LexPercent(); + case ':': + return TokKind::kColon; + case '[': + return TokKind::kLsquare; + case ']': + return TokKind::kRsquare; + case '{': + return TokKind::kLbrace; + case '}': + return TokKind::kRbrace; + case '(': + return TokKind::kLparen; + case ')': + return TokKind::kRparen; + } + } +} + +// Lex a shape, name, keyword, or opcode. +// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: +// keyword ::= HloModule, ENTRY, ... +// opcode ::= add, greater-than, ... +// attribute_name ::= condition, body, dimensions, ... +TokKind HloLexer::LexIdentifier() { + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + // 'consumable' will be advanced iff its prefix matches the pattern. + static LazyRE2 shape_pattern = { + R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"}; + if (RE2::Consume(&consumable, *shape_pattern)) { + auto status_or_shape = ShapeUtil::ParseShapeString( + StringPieceFromPointers(token_start_, consumable.begin())); + if (status_or_shape.ok()) { + // This is a shape string. + shape_val_ = status_or_shape.ValueOrDie(); + current_ptr_ = consumable.begin(); + return TokKind::kShape; + } + } + } + + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + + // If followed by ':', it's a name. + if (PeekCurrentChar() == ':') { + str_val_.assign(token_start_, current_ptr_); + current_ptr_++; // skip ':' + return TokKind::kName; + } + + // If followed by '=', it's a attribute name. + if (PeekCurrentChar() == '=') { + str_val_.assign(token_start_, current_ptr_); + current_ptr_++; // skip '=' + return TokKind::kAttributeName; + } + + StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + + // See if this is a keyword. +#define KEYWORD(STR) \ + do { \ + if (identifier == #STR) { \ + return TokKind::kw_##STR; \ + } \ + } while (false) + + KEYWORD(true); + KEYWORD(false); + KEYWORD(HloModule); + KEYWORD(ENTRY); + KEYWORD(ROOT); + +#undef KEYWORD + + // See if this is an opcode. + auto opcode = StringToHloOpcode(identifier.ToString()); + if (opcode.ok()) { + opcode_val_ = opcode.ValueOrDie(); + return TokKind::kOpcode; + } + + current_ptr_ = token_start_ + 1; + return TokKind::kError; +} + +// Lex names after a % character. +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]* +TokKind HloLexer::LexPercent() { + const char* name_start = current_ptr_; + if (isalpha(static_cast(PeekCurrentChar())) || + PeekCurrentChar() == '_') { + current_ptr_++; + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + str_val_.assign(name_start, current_ptr_); + return TokKind::kName; + } + return TokKind::kError; +} + +// Lex integer and floating-point values. +// int [-]?[0-9]+ +// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +TokKind HloLexer::LexDigitOrNegative() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 float_pattern = { + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + if (RE2::Consume(&consumable, *float_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), + &decimal_val_); + return TokKind::kDecimal; + } + + static LazyRE2 int_pattern = {R"([-]?\d+)"}; + if (RE2::Consume(&consumable, *int_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strto64( + StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); + return TokKind::kInt; + } + + return TokKind::kError; +} + +StringPiece HloLexer::GetCurrentLine() const { + const char* start = token_start_; + const char* end = current_ptr_; + if (!CanDereference(start) || !CanDereference(end)) { + return "LINE OUT OF RANGE"; + } + while (start > buf_.begin() && *start != '\n') { + start--; + } + while (end < buf_.end() && *end != '\n') { + end++; + } + return StringPieceFromPointers(start, end); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..433a3a3601e969de154d2f463f650f5f0b07a49f --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace tools { + +// Lexer for the HloModule::ToString() format text. +class HloLexer { + public: + explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + current_ptr_ = buf_.begin(); + } + + TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } + string GetStrVal() const { + switch (GetKind()) { + case TokKind::kName: + case TokKind::kAttributeName: + return str_val_; + default: + LOG(FATAL) << "This token does not have string value"; + } + } + Shape GetShapeVal() const { + CHECK(GetKind() == TokKind::kShape); + return shape_val_; + } + HloOpcode GetOpcodeVal() const { + CHECK(GetKind() == TokKind::kOpcode); + return opcode_val_; + } + int64 GetInt64Val() const { + CHECK(GetKind() == TokKind::kInt); + return int64_val_; + } + double GetDecimalVal() const { + CHECK(GetKind() == TokKind::kDecimal); + return decimal_val_; + } + + // Returns the line of text that is currently being lexed. + tensorflow::StringPiece GetCurrentLine() const; + + private: + // Returns the current character. If it's neither the end of input buffer nor + // an invalid character, moves the pointer forward. + int GetNextChar(); + + // Returns the current character. + int PeekCurrentChar() const; + + // Creates StringPiece with the given begin and end. Exits if the begin > end, + // or it's out of the range of the current buffer. + tensorflow::StringPiece StringPieceFromPointers(const char* begin, + const char* end) const; + tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( + const char* begin, const char* end) const; + + // Returns true if the given ptr is dereferenceable within the range of the + // current buffer. + bool CanDereference(const char* ptr) const; + + TokKind LexToken(); + + TokKind LexIdentifier(); + TokKind LexPercent(); + TokKind LexShape(); + TokKind LexConstant(); + TokKind LexDigitOrNegative(); + + const tensorflow::StringPiece buf_; + const char* current_ptr_; + + // Information about the current token. + const char* token_start_; + TokKind current_kind_; + string str_val_; + Shape shape_val_; + HloOpcode opcode_val_; + int64 int64_val_; + double decimal_val_; +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91404d73a1d76822c629512b4eb62dfe3a73579 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -0,0 +1,702 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace tools { + +namespace { + +using tensorflow::StringPiece; +using tensorflow::strings::StrCat; + +// Parser for the HloModule::ToString() format text. +class HloParser { + public: + explicit HloParser(StringPiece str) : lexer_(str) {} + + // Runs the parser. Returns false if an error occurred. + bool Run(); + + // Returns the parsed HloModule. + std::unique_ptr ConsumeHloModule() { return std::move(module_); } + + // Returns the error information. + string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + + private: + // ParseXXX returns false if an error occurred. + bool ParseHloModule(); + bool ParseComputations(); + bool ParseComputation(); + bool ParseInstructionList(HloComputation::Builder* builder, + string* root_name); + bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseOperands(std::vector* operands); + // Fill parsed operands into 'operands' and expect a certain number of + // operands. + bool ParseOperands(std::vector* operands, + const int expected_size); + + template + bool ParseExtraAttribute(T* value, const string& expected_attribute); + template + bool ParseAttributeValue(T* value); + + bool ParseParamList(); + bool ParseName(string* result); + bool ParseAttributeName(string* result); + bool ParseShape(Shape* result); + bool ParseOpcode(HloOpcode* result); + bool ParseInt64(int64* result); + bool ParseDecimal(double* result); + bool ParseBool(bool* result); + bool ParseToken(TokKind kind, const string& msg); + + // Logs the current parsing line and the given message. Always returns false. + bool TokenError(StringPiece msg); + + // If the current token is 'kind', eats it (i.e. lexes the next token) and + // returns true. + bool EatIfPresent(TokKind kind); + + // Adds the instruction to the pool. Returns false and emits an error if the + // instruction already exists. + bool AddInstruction(const string& name, HloInstruction* instruction); + // Adds the computation to the pool. Returns false and emits an error if the + // computation already exists. + bool AddComputation(const string& name, HloComputation* computation); + + // The map from the instruction name to the instruction. This does not own the + // instructions. + std::unordered_map instruction_pool_; + std::unordered_map computation_pool_; + + HloLexer lexer_; + std::unique_ptr module_; + std::vector error_; +}; + +bool HloParser::TokenError(StringPiece msg) { + error_.push_back( + StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg)); + return false; +} + +bool HloParser::Run() { + lexer_.Lex(); + return ParseHloModule(); +} + +// ::= 'HloModule' name computations +bool HloParser::ParseHloModule() { + if (lexer_.GetKind() != TokKind::kw_HloModule) { + return TokenError("expects HloModule"); + } + // Eat 'HloModule' + lexer_.Lex(); + + string name; + if (!ParseName(&name)) { + return false; + } + + module_ = MakeUnique(name); + + return ParseComputations(); +} + +// computations ::= (computation)+ +bool HloParser::ParseComputations() { + do { + if (!ParseComputation()) { + return false; + } + } while (lexer_.GetKind() != TokKind::kEof); + return true; +} + +// computation ::= ('ENTRY')? name param_list '->' shape instruction_list +bool HloParser::ParseComputation() { + const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); + string name; + if (!ParseName(&name)) { + return false; + } + auto builder = MakeUnique(name); + + Shape shape; + string root_name; + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || + !ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) { + return false; + } + + HloInstruction* root = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, root_name); + // This means some instruction was marked as ROOT but we didn't find it in the + // pool, which should not happen. + if (!root_name.empty() && root == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // root instruction. + HloComputation* computation = + is_entry_computation + ? module_->AddEntryComputation(builder->Build(root)) + : module_->AddEmbeddedComputation(builder->Build(root)); + return AddComputation(name, computation); +} + +// instruction_list ::= '{' instruction_list1 '}' +// instruction_list1 ::= (instruction)+ +bool HloParser::ParseInstructionList(HloComputation::Builder* builder, + string* root_name) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of instruction list.")) { + return false; + } + do { + if (!ParseInstruction(builder, root_name)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list."); +} + +// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)* +bool HloParser::ParseInstruction(HloComputation::Builder* builder, + string* root_name) { + string name; + Shape shape; + HloOpcode opcode; + std::vector operands; + bool is_root = EatIfPresent(TokKind::kw_ROOT); + if (!ParseName(&name) || + !ParseToken(TokKind::kEqual, "expects '=' in instruction") || + !ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + if (is_root) { + *root_name = name; + } + HloInstruction* instruction; + switch (opcode) { + case HloOpcode::kParameter: { + int64 parameter_number; + if (!ParseToken(TokKind::kLparen, + "expects '(' before parameter number") || + !ParseInt64(¶meter_number) || + !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateParameter(parameter_number, shape, name)); + break; + } + case HloOpcode::kConstant: { + std::unique_ptr literal; + if (!ParseToken(TokKind::kLparen, + "expects '(' before constant literal") || + !ParseLiteral(&literal, shape) || + !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + break; + } + // Unary ops. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kReal: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSort: + case HloOpcode::kTanh: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateUnary(shape, opcode, operands[0])); + break; + } + // Binary ops. + case HloOpcode::kAdd: + case HloOpcode::kDivide: + case HloOpcode::kMultiply: + case HloOpcode::kSubtract: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: + case HloOpcode::kDot: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: { + if (!ParseOperands(&operands, /*expected_size=*/2)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBinary( + shape, opcode, operands[0], operands[1])); + break; + } + // Ternary ops. + case HloOpcode::kClamp: + case HloOpcode::kSelect: { + if (!ParseOperands(&operands, /*expected_size=*/3)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateTernary( + shape, opcode, operands[0], operands[1], operands[2])); + break; + } + // Other supported ops. + case HloOpcode::kConvert: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateConvert(shape, operands[0])); + break; + } + case HloOpcode::kCrossReplicaSum: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCrossReplicaSum(shape, operands[0])); + break; + } + case HloOpcode::kReshape: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateReshape(shape, operands[0])); + break; + } + case HloOpcode::kTuple: { + if (!ParseOperands(&operands)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateTuple(operands)); + break; + } + case HloOpcode::kWhile: { + HloComputation* condition; + HloComputation* body; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseExtraAttribute(&condition, + /*expected_attribute=*/"condition") || + !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateWhile( + shape, condition, body, /*init=*/operands[0])); + break; + } + case HloOpcode::kRecv: { + int64 channel_id; + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseExtraAttribute(&channel_id, + /*expected_attribute=*/"channel_id")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRecv(shape, channel_id)); + break; + } + case HloOpcode::kSend: { + int64 channel_id; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseExtraAttribute(&channel_id, + /*expected_attribute=*/"channel_id")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateSend(operands[0], channel_id)); + break; + } + case HloOpcode::kGetTupleElement: { + int64 index; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGetTupleElement(shape, operands[0], index)); + break; + } + case HloOpcode::kCall: { + HloComputation* to_apply; + if (!ParseOperands(&operands) || + !ParseExtraAttribute(&to_apply, + /*expected_attribute=*/"to_apply")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCall(shape, operands, to_apply)); + break; + } + case HloOpcode::kBroadcast: + case HloOpcode::kCustomCall: + case HloOpcode::kConcatenate: + case HloOpcode::kReducePrecision: + case HloOpcode::kConvolution: + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReverse: + case HloOpcode::kRng: + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kTranspose: + case HloOpcode::kFusion: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kIndex: + case HloOpcode::kTrace: + return TokenError(StrCat("parsing not yet implemented for op: ", + HloOpcodeString(opcode))); + } + // Parse "device=". + if (lexer_.GetKind() == TokKind::kComma) { + int64 device; + if (!ParseExtraAttribute(&device, /*expected_attribute=*/"device")) { + return false; + } + OpDeviceAssignment assignment; + assignment.set_has_device(true); + assignment.set_device(device); + instruction->set_device_assignment(assignment); + } + + return AddInstruction(name, instruction); +} + +bool HloParser::ParseLiteral(std::unique_ptr* literal, + const Shape& shape) { + switch (shape.element_type()) { + case PRED: + bool b; + if (!ParseBool(&b)) { + return false; + } + *literal = Literal::CreateR0(b); + return true; + case S32: + int64 i; + if (!ParseInt64(&i)) { + return false; + } + *literal = Literal::CreateR0(i); + return true; + case F32: + double d; + if (!ParseDecimal(&d)) { + return false; + } + *literal = Literal::CreateR0(d); + return true; + default: + return TokenError(StrCat("unsupported constant in shape: ", + ShapeUtil::HumanString(shape))); + } +} + +// operands ::= '(' operands1 ')' +// operands1 +// ::= /*empty*/ +// ::= operand (, operand)* +// operand ::= shape name +bool HloParser::ParseOperands(std::vector* operands) { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of operands")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + string name; + if (!ParseShape(&shape) || !ParseName(&name)) { + return false; + } + HloInstruction* instruction = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!instruction) { + return TokenError(StrCat("instruction does not exist: ", name)); + } + operands->push_back(instruction); + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); +} + +bool HloParser::ParseOperands(std::vector* operands, + const int expected_size) { + if (!ParseOperands(operands)) { + return false; + } + if (expected_size != operands->size()) { + return TokenError(StrCat("expects ", expected_size, " operands, but has ", + operands->size(), " operands")); + } + return true; +} + +// extra_attribute ::= ',' attribute_name value +template +bool HloParser::ParseExtraAttribute(T* value, + const string& expected_attribute) { + if (!ParseToken(TokKind::kComma, + "expects ',' in front of an extra attribute")) { + return false; + } + string attribute_name; + if (!ParseAttributeName(&attribute_name) && + attribute_name != expected_attribute) { + return TokenError(StrCat("expects attribute name: ", expected_attribute)); + } + if (!ParseAttributeValue(value)) { + return TokenError( + StrCat("expects value for attribute: ", expected_attribute)); + } + return true; +} + +template <> +bool HloParser::ParseAttributeValue(HloComputation** value) { + string name; + if (!ParseName(&name)) { + return TokenError("expects computation name"); + } + *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); + if (*value == nullptr) { + return TokenError(StrCat("computation does not exist: ", name)); + } + return true; +} + +template <> +bool HloParser::ParseAttributeValue(int64* value) { + return ParseInt64(value); +} + +// param_list ::= '(' param_list1 ')' +// param_list1 +// ::= /*empty*/ +// ::= param (',' param)* +// param ::= name shape +bool HloParser::ParseParamList() { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of param list")) { + return false; + } + + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + if (!ParseToken(TokKind::kName, "expects name in parameter") || + !ParseShape(&shape)) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); +} + +// shape ::= shape_val_ +// shape ::= '(' tuple_elements ')' +// tuple_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShape(Shape* result) { + if (EatIfPresent(TokKind::kLparen)) { // Tuple + std::vector shapes; + if (lexer_.GetKind() == TokKind::kRparen) { + /*empty*/ + } else { + // shape (',' shape)* + do { + shapes.emplace_back(); + if (!ParseShape(&shapes.back())) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + *result = ShapeUtil::MakeTupleShape(shapes); + return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); + } + + if (lexer_.GetKind() != TokKind::kShape) { + return TokenError("expects shape"); + } + *result = lexer_.GetShapeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseName(string* result) { + VLOG(1) << "ParseName"; + if (lexer_.GetKind() != TokKind::kName) { + return TokenError("expects name"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseAttributeName(string* result) { + if (lexer_.GetKind() != TokKind::kAttributeName) { + return TokenError("expects attribute name"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseOpcode(HloOpcode* result) { + VLOG(1) << "ParseOpcode"; + if (lexer_.GetKind() != TokKind::kOpcode) { + return TokenError("expects opcode"); + } + *result = lexer_.GetOpcodeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseInt64(int64* result) { + VLOG(1) << "ParseInt64"; + if (lexer_.GetKind() != TokKind::kInt) { + return TokenError("expects integer"); + } + *result = lexer_.GetInt64Val(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDecimal(double* result) { + switch (lexer_.GetKind()) { + case TokKind::kDecimal: + *result = lexer_.GetDecimalVal(); + break; + case TokKind::kInt: + *result = static_cast(lexer_.GetInt64Val()); + break; + default: + return TokenError("expects decimal or integer"); + } + lexer_.Lex(); + return true; +} + +bool HloParser::ParseBool(bool* result) { + if (lexer_.GetKind() != TokKind::kw_true && + lexer_.GetKind() != TokKind::kw_false) { + return TokenError("expects true or false"); + } + *result = lexer_.GetKind() == TokKind::kw_true; + lexer_.Lex(); + return true; +} + +bool HloParser::ParseToken(TokKind kind, const string& msg) { + if (lexer_.GetKind() != kind) { + return TokenError(msg); + } + lexer_.Lex(); + return true; +} + +bool HloParser::EatIfPresent(TokKind kind) { + if (lexer_.GetKind() != kind) { + return false; + } + lexer_.Lex(); + return true; +} + +bool HloParser::AddInstruction(const string& name, + HloInstruction* instruction) { + auto result = instruction_pool_.insert({name, instruction}); + if (!result.second) { + return TokenError(StrCat("instruction already exists: ", name)); + } + return true; +} + +bool HloParser::AddComputation(const string& name, + HloComputation* computation) { + auto result = computation_pool_.insert({name, computation}); + if (!result.second) { + return TokenError(StrCat("computation already exists: ", name)); + } + return true; +} + +} // namespace + +StatusOr> Parse(StringPiece str) { + HloParser parser(str); + if (!parser.Run()) { + return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + } + return parser.ConsumeHloModule(); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..9aaf18ef20d769cd9ac6f0e48bc92f62292ba31a --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace tools { + +// The api of the hlo parser. Given a string in the HloModule::ToString() +// format, returns the parsed HloModule. +StatusOr> Parse(tensorflow::StringPiece str); + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bf1cce1c02146af2f495931a55cdec60959deea --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -0,0 +1,321 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace tools { +namespace { + +struct TestData { + string test_name; + string module_string; +}; + +string TestDataToString(const ::testing::TestParamInfo& data) { + return data.param.test_name; +} + +std::vector CreateTestCases() { + // clang-format off + return std::vector({ +// ax + y +{ +"AxpyParam", +R"(HloModule axpy_module: + +ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[2,4]{1,0} parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} + +)" +}, +// pred constant +{ +"ConstantPred", +R"(HloModule constant_pred_module: + +ENTRY %constant_pred () -> pred[] { + ROOT %constant = pred[] constant(true) +} + +)" +}, +// s32 constant +{ +"ConstantS32", +R"(HloModule constant_s32_module: + +ENTRY %constant_s32 () -> s32[] { + ROOT %constant = s32[] constant(-42) +} + +)" +}, +// f32 constant, but the value is not a decimal +{ +"ConstantF32", R"(HloModule ConstantF32_module: + +ENTRY %ConstantF32.v4 () -> f32[] { + ROOT %constant = f32[] constant(42) +} + +)" +}, +// constant + constant +{ +"AddConstants", +R"(HloModule add_constants_module: + +ENTRY %add_constants () -> f32[] { + %constant = f32[] constant(3.14) + ROOT %add = f32[] add(f32[] %constant, f32[] %constant) +} + +)" +}, +// v1 > v2 ? v1 : v2 +{ +"SelectR1F32", +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: + +ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { + %v1 = f32[4]{0} parameter(0), device=1 + %v2 = f32[4]{0} parameter(1), device=1 + %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2) + ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) +} + +)" +}, +// empty tuple +{ +"EmptyTupleCreate", +R"(HloModule EmptyTupleCreate_module: + +ENTRY %EmptyTupleCreate.v1 () -> () { + ROOT %tuple = () tuple() +} + +)" +}, +// tuple +{ +"TupleCreate", +R"(HloModule TupleCreate_module: + +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} + +)" +}, +// int32 result = 0; +// while (result < 5) { result = result + 1; } +{ +"WhileWithScalarS32Result", +R"(HloModule WhileWithScalarS32Result_module: + +%body.v3 (prev.1: s32[]) -> s32[] { + %constant = s32[] constant(1) + %prev.1 = s32[] parameter(0) + ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1) +} + +%condition.v3 (prev.2: s32[]) -> pred[] { + %constant.1 = s32[] constant(5) + %prev.2 = s32[] parameter(0) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2) +} + +ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { + %constant.2 = s32[] constant(0) + ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3 +} + +)" +}, +// send and recv +{ +"SendRecv", +R"(HloModule TwoSendRecvBothWayRecvFist_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = f32[] recv(), channel_id=15, device=1 + ROOT %constant = f32[] constant(2.1), device=0 + %send = () send(f32[] %constant), channel_id=16, device=0 +} + +)" +}, +// get-tuple-element +{ +"GetTupleElement", +R"(HloModule GetTupleElement_module: + +ENTRY %GetTupleElement.v4 () -> s32[] { + %constant = f32[] constant(1.23) + %constant.1 = s32[] constant(4) + %tuple = (f32[], s32[]) tuple(f32[] %constant, s32[] %constant.1) + ROOT %get-tuple-element = s32[] get-tuple-element((f32[], s32[]) %tuple), index=1, device=0 +} + +)" +}, +// call +{ +"Call", +R"(HloModule CallR0F32IdentityScalar_module: + +%Identity.v1 (x: f32[]) -> f32[] { + ROOT %x = f32[] parameter(0) +} + +ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { + %constant = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 +} + +)" +} + }); + // clang-format on +} + +class HloParserTest : public ::testing::Test, + public ::testing::WithParamInterface { + protected: + void ExpectSuccess() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(original, result.ValueOrDie()->ToString()); + } +}; + +TEST_P(HloParserTest, Run) { ExpectSuccess(); } + +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); + +TEST_F(HloParserTest, Empty) { + const string original = ""; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, Garbage) { + const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOpcode) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { + %x = f32[]{} parameter(0) + %y = f32[]{} parameter(1) + %le = pred[]{} le(f32[]{} %x, f32[]{} %y) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongShape) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: g32[]) -> g32[] { + %x = g32[]{} parameter(0) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOperandsSize) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, OperandNotFound) { + const string original = R"(HloModule operand_not_found: +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) +} +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, MoreConstants) { + const string original = R"(HloModule SelectScalarS32True_module: + +ENTRY %SelectScalarS32True.v4 () -> s32[] { + %constant.2 = pred[] constant(true) + %constant.1 = s32[] constant(-42) + %constant = s32[] constant(42) + %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // Constant instructions have no name. The string will be parsed successfully + // but the constant names will not be exactly the same. +} + +TEST_F(HloParserTest, ConstantWithExp) { + const string original = R"(HloModule ConstantWithExp_module: + +ENTRY %ConstantWithExp.v4 () -> f32[] { + %constant.1 = f32[] constant(3e+2) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // The string will be parsed successfully but the output strings are not + // exactly the same, because "3e2" is parsed into value 300 and will be + // printed as "300". +} + +} // namespace +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h new file mode 100644 index 0000000000000000000000000000000000000000..1d56ea347823aeec5bead6925ece6a7296b596af --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ + +namespace xla { +namespace tools { + +// Defines different kinds of tokens in a hlo module string. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + + // Typed tokens. + kName, // %foo + kAttributeName, // dimensions= + kShape, // f32[2,3]{1,0} + kOpcode, // add + kInt, // 42 + kDecimal, // 4.2 +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index ea8b4b7b989b72034f33920a7d8c1a75e15a7dd1..3b19ca321cad35aad18f7f498e08fd744ffbc371 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_ #define TENSORFLOW_COMPILER_XLA_TYPES_H_ +#include + #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" @@ -35,6 +37,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +using complex64 = std::complex; + using ::Eigen::half; } // namespace xla diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index f6c0bd1563f4d9090df94b6edd8226119194c76c..f58f57b44396c90a3820835a3d0ecc792aaa7cd0 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 22e70ec97adf9297ceb3f98f57feb17ae9dafc3d..3fa5bcc1df4f0294582b6c74735fef08c87433eb 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,11 +17,3 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): protoc="@protobuf_archive//:protoc", testonly=testonly, visibility=visibility,) - -# Flags required for modules that export symbols that are to be called by the -# XLA CustomCall operator. CustomCall must be able to find symbols with dlsym(), -# which on Linux requires we link with --export-dynamic. -export_dynamic_linkopts = select({ - "//tensorflow:darwin": [], - "//conditions:default": ["-Wl,--export-dynamic"], -}) diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 4840ddb8817a37c7dabcfb27e24a2a5472f4b6a2..ce3c3eee68ad7f7ebb42836e3cae14803f8650d7 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -82,8 +82,8 @@ message DebugOptions { // Dump all HLO modules as text into the provided directory path. string xla_generate_hlo_text_to = 7; - // Dump compilation artifacts as JSON into this directory. - string xla_dump_debug_json_to = 8; + // Dump compilation artifacts in binary proto into this directory. + string xla_dump_hlo_proto_to = 8; // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; @@ -191,6 +191,11 @@ message ExecutionOptions { uint64 seed = 3; DebugOptions debug_options = 4; + + // This optional field specifies a particular set of devices to run the + // computation on. The computation will be partitioned across these devices. + // If not provided, the default device will be chosen. + repeated DeviceHandle device_handles = 5; } message SnapshotComputationRequest { @@ -312,12 +317,8 @@ message ExecuteRequest { ComputationHandle computation = 1; repeated GlobalDataHandle arguments = 2; - // This optional field specifies a particular device to run the computation. - // If not provided, the default device will be chosen. - DeviceHandle device_handle = 5; - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; + ExecutionOptions execution_options = 5; } message ExecuteParallelRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 1771a3d5deaac0388c8c3ba6a2f283231ebd3572..fe47f85c1271851ecee496c7f101dfc7e58a49bd 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -48,6 +48,9 @@ enum PrimitiveType { F32 = 11; F64 = 12; + // Complex values of fixed width. + C64 = 15; // Paired F32 (real, imag), as in std::complex. + // A tuple is a polymorphic sequence; e.g. a shape that holds different // sub-shapes. They are used for things like returning multiple values from a // computation; e.g. a computation that returns weights and biases may have a @@ -305,6 +308,7 @@ message LiteralProto { repeated uint64 u64s = 7; repeated float f32s = 8; repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; bytes f16s = 11; // Note: the F16s are encoded in little endian byte order } @@ -392,13 +396,17 @@ message DynamicUpdateSliceRequest { } message ConvolutionDimensionNumbers { - // The number of the dimension that represents batch in the input - // (lhs) and output. - int64 batch_dimension = 1; + // The number of the dimension that represents batch in the input. + int64 input_batch_dimension = 7; + + // The number of the dimension that represents features in the input. + int64 input_feature_dimension = 8; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; - // The number of the dimension that represents features in the input - // (lhs) and output. - int64 feature_dimension = 2; + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; // The dimension numbers for the spatial dimensions that the window // moves through in the input (lhs) and output. @@ -617,8 +625,8 @@ message WhileRequest { enum UnaryOperation { UNOP_INVALID = 0; - // Elementwise, logical negation - UNOP_LOGICAL_NOT = 1; + // Elementwise, logical negation on booleans and bitwise negation on ints. + UNOP_NOT = 1; // Elementwise, computes e^x. UNOP_EXP = 2; @@ -659,6 +667,12 @@ enum UnaryOperation { // Elementwise, rounds x to nearest integral value, rounding half-way cases // away from zero. UNOP_ROUND_NEAREST_AFZ = 14; + + // Elementwise, extract real component of complex x. + UNOP_REAL = 15; + + // Elementwise, extract real component of complex x. + UNOP_IMAG = 16; } message UnaryOpRequest { @@ -706,9 +720,19 @@ enum BinaryOperation { // Remainder operation. BINOP_REM = 17; - // Logical operators - BINOP_LOGICAL_AND = 18; - BINOP_LOGICAL_OR = 19; + // Element-wise, logical operators on booleans and bitwise operators on ints. + BINOP_AND = 18; + BINOP_OR = 19; + + BINOP_SHIFT_LEFT = 20; + BINOP_SHIFT_RIGHT_ARITHMETIC = 21; + BINOP_SHIFT_RIGHT_LOGICAL = 22; + + // Complex from real, imag. + BINOP_COMPLEX = 23; + + // Computes the 4-quadrant arctangent of the y, x input arguments. + BINOP_ATAN2 = 24; } message BinaryOpRequest { @@ -747,10 +771,6 @@ enum TernaryOperation { // true and operand1 if the predicate is false. TRIOP_SELECT = 1; - // Updates operand0 at index operand1 with value operand2 and outputs the - // updated value. - TRIOP_UPDATE = 2; - // Given a min, max and an operand returns the operand if between min and max, // else returns min if operand is less than min or max if operand is greater // than max. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 65c966aa0330b6d50b2285c19bfa6f118cff4e47..2e9b96bb1d31f7c985df992c094784660d6e274c 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -53,6 +53,7 @@ py_library( "//tensorflow/contrib/linear_optimizer:sdca_ops_py", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", + "//tensorflow/contrib/losses:metric_learning_py", "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", @@ -79,7 +80,7 @@ py_library( "//tensorflow/contrib/staging", "//tensorflow/contrib/stat_summarizer:stat_summarizer_py", "//tensorflow/contrib/stateless", - "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/contrib/summary:summary", "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensorboard", "//tensorflow/contrib/testing:testing_py", @@ -87,8 +88,10 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/timeseries", "//tensorflow/contrib/tpu", + "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", + "//tensorflow/python:util", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]), ) @@ -104,6 +107,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", + "//tensorflow/contrib/rnn:all_kernels", "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", "//tensorflow/contrib/tensor_forest:model_ops_kernels", "//tensorflow/contrib/tensor_forest:stats_ops_kernels", @@ -125,6 +129,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", + "//tensorflow/contrib/rnn:all_ops", "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", "//tensorflow/contrib/tensor_forest:model_ops_op_lib", "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index bf921808aa9a4694e06afcc2091b381a6fcffc49..a26fdb982c0f4d6d85b73912c194647a989d0ef6 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -77,9 +77,11 @@ from tensorflow.contrib import timeseries from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util +from tensorflow.contrib.eager.python import tfe as eager from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs +from tensorflow.contrib.summary import summary from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 744ae4c1f413bc1854a07ead9a3fa6bc90ed2fc1..35b9de27e7294f5911362e8ff5c8a27ea2b76c7f 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -19,9 +19,10 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/nccl:nccl_ops", + "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", ], ) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 8e7f1791b864bb30e4592a86e637d1603be6618b..a5057da9fd43a88575813613d6ac9d17fd2b2e28 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -191,7 +191,7 @@ def _ragged_split(tensor, pieces): def _ring_permutations(num_workers, num_subchunks, gpu_perm): - """"Generate an array of device index arrays, one for for each subchunk. + """"Generate an array of device index arrays, one for each subchunk. In the basic ring reduction algorithm there are size(T)/num_devices data chunks and each device process one chunk per tick, i.e. sending @@ -762,6 +762,8 @@ def _reduce_non_singleton(input_tensors, red_f, un_op): if len(input_tensors) > 1: return red_f(input_tensors) else: + if not un_op: + return input_tensors output_tensors = [] for t in input_tensors: with ops.colocate_with(t): @@ -835,7 +837,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op): + red_n_op, red_op, un_op=None): """Construct hybrid of Shuffle within workers, Ring across workers.""" def upper_builder(tensors): return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 743a12b925700089b649c273c7573699ce44df9e..80e03f20362ed41b62ce118e864ffb0acb4ab50b 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -35,8 +35,8 @@ import org.tensorflow.Graph; import org.tensorflow.Operation; import org.tensorflow.Session; import org.tensorflow.Tensor; -import org.tensorflow.Tensors; import org.tensorflow.TensorFlow; +import org.tensorflow.Tensors; import org.tensorflow.types.UInt8; /** diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 1555a3427fd5e40ca54c134a2c80f9d2c5feca36..ae3f48f1b276b1f13078e8845c4c87cf3473513f 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -69,6 +69,28 @@ tf_cc_test( ], ) +cc_library( + name = "adaptive_shared_batch_scheduler", + hdrs = ["adaptive_shared_batch_scheduler.h"], + deps = [ + ":batch_scheduler", + "//tensorflow/contrib/batching/util:periodic_function_dynamic", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "adaptive_shared_batch_scheduler_test", + srcs = ["adaptive_shared_batch_scheduler_test.cc"], + deps = [ + ":adaptive_shared_batch_scheduler", + "//tensorflow/contrib/batching/test_util:fake_clock_env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..a0606427a526ffc67e10d12a084eabc64564e4ab --- /dev/null +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -0,0 +1,463 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/batching/batch_scheduler.h" +#include "tensorflow/contrib/batching/util/periodic_function.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { +namespace internal { +template +class ASBSBatch; + +template +class ASBSQueue; +} // namespace internal + +// Shared batch scheduler designed to minimize latency. The scheduler keeps +// track of a number of queues (one per model or model version) which are +// continuously enqueuing requests. The scheduler groups the requests into +// batches which it periodically sends off for processing (see +// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler +// prioritizes batches by age (i.e. the batch's oldest request) irrespective of +// queue. The scheduler will process the oldest batch at an adjustable rate, +// regardless of batch size. The user can provide feedback to help set this rate +// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc). +// +// The rate (or rather, the corresponding period) is adjusted each time a batch +// is processed, using an exponentially weighted moving average to smooth +// potentially noisy feedback: +// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N +// period *= (1 + K * emwa_feedback) +// +// Some potential use cases: +// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing +// involves serial processing by a device, from a latency perspective it is +// desirable to keep the device evenly loaded, avoiding the need to wait for +// the device to process prior batches. +// feedback = num_pending_on_device() - desired_pending. +// CPU utilization - If the batch processing is cpu dominated, you can reap +// latency gains when underutilized by increasing the processing rate, but +// back the rate off when the load increases to avoid overload. +// feedback = cpu_rate() - desired_cpu_rate. + +template +class AdaptiveSharedBatchScheduler + : public std::enable_shared_from_this< + AdaptiveSharedBatchScheduler> { + public: + struct Options { + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + // Number of batch processing threads; equivalently the maximum number of + // concurrently running batches. + int64 num_batch_threads = port::NumSchedulableCPUs(); + // The environment to use (typically only overridden by test code). + Env* env = Env::Default(); + // Initial batch scheduling period in microseconds. Will be altered for + // non-zero rate_feedback. + double initial_scheduling_period_micros = 500; + // Minimum batch scheduling period in microseconds. Recommend setting this + // value greater than 0, otherwise it may take a while to recover from a + // sustained time of negative scheduling_period_feedback (which may occur + // under low load). + double min_scheduling_period_micros = 100; + // Maximum batch scheduling period in microseconds. + double max_scheduling_period_micros = 10000; + // Feedback function used to modify the scheduling period each time a batch + // is scheduled. Should return values roughly O(1), with positive values + // resulting in an increased period. + std::function scheduling_period_feedback{[] { return 0.; }}; + // To handle potentially noisy scheduling_period_feedback, the period is + // adjusted using an exponentially weighted moving average over the previous + // feedback_smoothing_batches batches. Must be greater than 0. + int64 feedback_smoothing_batches = 10; + }; + + // Ownership is shared between the caller of Create() and any queues created + // via AddQueue(). + static Status Create( + const Options& options, + std::shared_ptr>* scheduler); + + struct QueueOptions { + // Maximum size of each batch. + int max_batch_size = 1000; + // Maximum number of enqueued (i.e. non-scheduled) batches. + int max_enqueued_batches = 10; + }; + + using BatchProcessor = std::function>)>; + + // Adds queue (and its callback) to be managed by this scheduler. + Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); + + private: + // access to AddBatch, RemoveQueue, GetEnv. + friend class internal::ASBSQueue; + + explicit AdaptiveSharedBatchScheduler(const Options& options); + + // Batch scheduling function which runs every scheduling_period_ microseconds. + void ProcessOneBatch(); + + // Notifies scheduler of non-empty batch which is eligible for processing. + void AddBatch(internal::ASBSBatch*); + + // Removes queue from scheduler. + void RemoveQueue(const internal::ASBSQueue* queue); + + Env* GetEnv() const { return options_.env; } + + const Options options_; + + struct BatchCompare { + bool operator()(const internal::ASBSBatch* a, + const internal::ASBSBatch* b); + }; + + // Collection of batches added by AddBatch, ordered by age. Owned by scheduler + // until they are released for processing. + std::priority_queue*, + std::vector*>, BatchCompare> + batches_ GUARDED_BY(mu_); + + // Unowned queues and callbacks added by AddQueue. + std::unordered_map*, BatchProcessor> + queues_and_callbacks_ GUARDED_BY(mu_); + + mutex mu_; + + // Responsible for running ProcessOneBatch. PeriodicFunction was used in order + // to check for deletion so that the thread can be shut down. + std::unique_ptr scheduling_thread_; + + // Responsible for running the batch processing callbacks. + std::unique_ptr batch_thread_pool_; + + // Time interval in microseconds between successive ProcessOneBatch calls. + double scheduling_period_; + + // Exponentially weighted moving average of + // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch + // call. + double ewma_feedback_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler); +}; + +////////////////////////////////////////////////////////// +// Implementation details follow. API users need not read. + +namespace internal { +// Consolidates tasks into batches, passing them off to the +// AdaptiveSharedBatchScheduler for processing. +template +class ASBSQueue : public BatchScheduler { + public: + using QueueOptions = + typename AdaptiveSharedBatchScheduler::QueueOptions; + + ASBSQueue(std::shared_ptr> scheduler, + const QueueOptions& options); + + ~ASBSQueue() override; + + // Adds task to current batch. Fails if the task size is larger than the batch + // size or if the current batch is full and this queue's number of outstanding + // batches is at its maximum. + Status Schedule(std::unique_ptr* task) override; + + // Number of tasks waiting to be scheduled. + size_t NumEnqueuedTasks() const override; + + // Number of size 1 tasks which could currently be scheduled without failing. + size_t SchedulingCapacity() const override; + + // Notifies queue that a batch is about to be scheduled; the queue should not + // place any more tasks in this batch. + void ReleaseBatch(const ASBSBatch* batch); + + private: + std::shared_ptr> scheduler_; + const QueueOptions options_; + // Owned by scheduler_. + ASBSBatch* current_batch_ GUARDED_BY(mu_) = nullptr; + int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0; + int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0; + mutable mutex mu_; + TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue); +}; + +// Batch which remembers when and by whom it was created. +template +class ASBSBatch : public Batch { + public: + ASBSBatch(ASBSQueue* queue, int64 creation_time_micros) + : queue_(queue), creation_time_micros_(creation_time_micros) {} + + ~ASBSBatch() override {} + + ASBSQueue* queue() const { return queue_; } + + int64 creation_time_micros() const { return creation_time_micros_; } + + private: + ASBSQueue* queue_; + const int64 creation_time_micros_; + TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch); +}; +} // namespace internal + +// ---------------- AdaptiveSharedBatchScheduler ---------------- + +template +Status AdaptiveSharedBatchScheduler::Create( + const Options& options, + std::shared_ptr>* scheduler) { + if (options.num_batch_threads < 1) { + return errors::InvalidArgument("num_batch_threads must be positive; was ", + options.num_batch_threads); + } + if (options.min_scheduling_period_micros < 0) { + return errors::InvalidArgument( + "min_scheduling_period_micros must be >= 0; was ", + options.min_scheduling_period_micros); + } + if (options.min_scheduling_period_micros > + options.initial_scheduling_period_micros) { + return errors::InvalidArgument( + "initial_scheduling_period_micros (", + options.initial_scheduling_period_micros, + ") must be >= min_scheduling_period_micros (", + options.min_scheduling_period_micros, ")"); + } + if (options.initial_scheduling_period_micros > + options.max_scheduling_period_micros) { + return errors::InvalidArgument( + "initial_scheduling_period_micros (", + options.initial_scheduling_period_micros, + ") must be <= max_scheduling_period_micros (", + options.max_scheduling_period_micros, ")"); + } + if (options.feedback_smoothing_batches < 1) { + return errors::InvalidArgument( + "feedback_smoothing_batches must be positive; was ", + options.feedback_smoothing_batches); + } + scheduler->reset(new AdaptiveSharedBatchScheduler(options)); + return Status::OK(); +} + +template +AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( + const Options& options) + : options_(options), + scheduling_period_(options.initial_scheduling_period_micros) { + PeriodicFunction::Options opts; + opts.thread_name_prefix = "scheduling_thread"; + opts.env = GetEnv(); + scheduling_thread_.reset( + new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); + batch_thread_pool_.reset(new thread::ThreadPool( + GetEnv(), options.thread_pool_name, options.num_batch_threads)); +} + +template +Status AdaptiveSharedBatchScheduler::AddQueue( + const QueueOptions& options, BatchProcessor process_batch_callback, + std::unique_ptr>* queue) { + if (options.max_batch_size <= 0) { + return errors::InvalidArgument("max_batch_size must be positive; was ", + options.max_batch_size); + } + if (options.max_enqueued_batches <= 0) { + return errors::InvalidArgument( + "max_enqueued_batches must be positive; was ", + options.max_enqueued_batches); + } + internal::ASBSQueue* asbs_queue_raw; + queue->reset(asbs_queue_raw = new internal::ASBSQueue( + this->shared_from_this(), options)); + mutex_lock l(mu_); + queues_and_callbacks_[asbs_queue_raw] = process_batch_callback; + return Status::OK(); +} + +template +void AdaptiveSharedBatchScheduler::AddBatch( + internal::ASBSBatch* batch) { + mutex_lock l(mu_); + batches_.push(batch); +} + +template +void AdaptiveSharedBatchScheduler::RemoveQueue( + const internal::ASBSQueue* queue) { + mutex_lock l(mu_); + queues_and_callbacks_.erase(queue); +} + +template +void AdaptiveSharedBatchScheduler::ProcessOneBatch() { + static const double kFeedbackMultiplier = .001; + internal::ASBSBatch* batch = nullptr; + BatchProcessor callback; + const int64 start_time_micros = GetEnv()->NowMicros(); + { + mutex_lock l(mu_); + if (!batches_.empty()) { + batch = batches_.top(); + batches_.pop(); + callback = queues_and_callbacks_[batch->queue()]; + } + } + if (batch != nullptr) { + double feedback = options_.scheduling_period_feedback(); + const int64 N = options_.feedback_smoothing_batches; + ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N; + scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_); + if (scheduling_period_ < options_.min_scheduling_period_micros) { + scheduling_period_ = options_.min_scheduling_period_micros; + } else if (scheduling_period_ > options_.max_scheduling_period_micros) { + scheduling_period_ = options_.max_scheduling_period_micros; + } + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule([callback, batch] { + callback(std::unique_ptr>(batch)); + }); + } + const int64 sleep_time = + scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros); + if (sleep_time > 0) { + GetEnv()->SleepForMicroseconds(sleep_time); + } +} + +template +bool AdaptiveSharedBatchScheduler::BatchCompare::operator()( + const internal::ASBSBatch* a, + const internal::ASBSBatch* b) { + return a->creation_time_micros() > b->creation_time_micros(); +} + +// ---------------- ASBSQueue ---------------- + +namespace internal { +template +ASBSQueue::ASBSQueue( + std::shared_ptr> scheduler, + const QueueOptions& options) + : scheduler_(scheduler), options_(options) {} + +template +ASBSQueue::~ASBSQueue() { + // Wait until last batch has been scheduled. + const int kSleepMicros = 1000; + for (;;) { + { + mutex_lock l(mu_); + if (num_enqueued_batches_ == 0) { + break; + } + } + scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros); + } + scheduler_->RemoveQueue(this); +} + +template +Status ASBSQueue::Schedule(std::unique_ptr* task) { + bool added_new_batch = false; + size_t size = (*task)->size(); + if (size > options_.max_batch_size) { + return errors::InvalidArgument("Task size ", size, + " is larger than maximum batch size ", + options_.max_batch_size); + } + { + mutex_lock l(mu_); + // Current batch is full, create another if allowed. + if (current_batch_ && + current_batch_->size() + size > options_.max_batch_size) { + if (num_enqueued_batches_ >= options_.max_enqueued_batches) { + return errors::Unavailable("The batch scheduling queue is full"); + } + current_batch_->Close(); + current_batch_ = nullptr; + } + if (!current_batch_) { + added_new_batch = true; + num_enqueued_batches_++; + current_batch_ = + new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); + } + current_batch_->AddTask(std::move(*task)); + num_enqueued_tasks_++; + } + if (added_new_batch) scheduler_->AddBatch(current_batch_); + return Status::OK(); +} + +template +void ASBSQueue::ReleaseBatch(const ASBSBatch* batch) { + mutex_lock l(mu_); + num_enqueued_batches_--; + num_enqueued_tasks_ -= batch->num_tasks(); + if (batch == current_batch_) { + current_batch_->Close(); + current_batch_ = nullptr; + } +} + +template +size_t ASBSQueue::NumEnqueuedTasks() const { + mutex_lock l(mu_); + return num_enqueued_tasks_; +} + +template +size_t ASBSQueue::SchedulingCapacity() const { + mutex_lock l(mu_); + const int current_batch_capacity = + current_batch_ ? options_.max_batch_size - current_batch_->size() : 0; + const int spare_batches = + options_.max_enqueued_batches - num_enqueued_batches_; + return spare_batches * options_.max_batch_size + current_batch_capacity; +} +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a07cd6d834fa28904bf7748b16972cca217503c1 --- /dev/null +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -0,0 +1,438 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h" + +#include "tensorflow/contrib/batching/test_util/fake_clock_env.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace serving { +namespace anonymous { + +class FakeTask : public BatchTask { + public: + explicit FakeTask(size_t size) : size_(size) {} + + ~FakeTask() override = default; + + size_t size() const override { return size_; } + + private: + const size_t size_; + + TF_DISALLOW_COPY_AND_ASSIGN(FakeTask); +}; + +// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on +// that task. Returns the resulting status. +Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { + std::unique_ptr task(new FakeTask(task_size)); + Status status = scheduler->Schedule(&task); + // Schedule() should have consumed 'task' iff it returned Status::OK. + CHECK_EQ(status.ok(), task == nullptr); + return status; +} + +// Creates a thread that waits on 'start' and then advances the fake clock in +// 'env' in a loop until 'stop' is notified. Useful for allowing objects that +// use the clock to be destroyed. +std::unique_ptr CreateFakeClockAdvancerThread( + test_util::FakeClockEnv* env, Notification* start, Notification* stop) { + return std::unique_ptr(Env::Default()->StartThread( + {}, "FakeClockAdvancerThread", [env, start, stop] { + start->WaitForNotification(); + while (!stop->HasBeenNotified()) { + env->AdvanceByMicroseconds(10); + Env::Default()->SleepForMicroseconds(10); + } + })); +} + +TEST(AdaptiveSharedBatchSchedulerTest, Basic) { + for (const bool delete_scheduler_early : {false, true}) { + for (const bool delete_queue_1_early : {false, true}) { + int queue_0_tasks = 0; + auto queue_0_callback = + [&queue_0_tasks](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_0_tasks += batch->task(i).size(); + } + }; + int queue_1_tasks = 0; + auto queue_1_callback = + [&queue_1_tasks](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_1_tasks += batch->task(i).size(); + } + }; + { + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create({}, &scheduler)); + + // Create two queues. + std::unique_ptr> queue_0; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0)); + std::unique_ptr> queue_1; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1)); + + if (delete_scheduler_early) { + // Delete our copy of the scheduler. The queues should keep it alive + // under the covers. + scheduler = nullptr; + } + // Submit tasks to the two queues, and (optionally) remove the queues. + TF_ASSERT_OK(ScheduleTask(1, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(2, queue_1.get())); + TF_ASSERT_OK(ScheduleTask(3, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(4, queue_1.get())); + if (delete_queue_1_early) { + queue_1 = nullptr; + } + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + } + EXPECT_EQ(queue_0_tasks, 9); + EXPECT_EQ(queue_1_tasks, 6); + } + } +} + +TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { + using Scheduler = AdaptiveSharedBatchScheduler; + std::shared_ptr scheduler; + Scheduler::Options options; + options.num_batch_threads = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 50; + options.max_scheduling_period_micros = 100; + options.initial_scheduling_period_micros = 1; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 50; + options.max_scheduling_period_micros = 100; + options.initial_scheduling_period_micros = 1000; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 100; + options.max_scheduling_period_micros = 50; + options.initial_scheduling_period_micros = 75; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.feedback_smoothing_batches = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); +} + +TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue_0; + std::unique_ptr> queue_1; + int queue_0_tasks = 0; + int queue_1_tasks = 0; + auto queue_0_callback = [&queue_0_tasks, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_0_tasks += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + auto queue_1_callback = [&queue_1_tasks, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_1_tasks += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + AdaptiveSharedBatchScheduler::QueueOptions queue_options; + queue_options.max_batch_size = 10; + queue_options.max_enqueued_batches = 0; + // Queue must have max_enqueued_batchs > 1. + EXPECT_FALSE( + scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok()); + queue_options.max_enqueued_batches = 2; + TF_ASSERT_OK( + scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + queue_options.max_batch_size = 0; + // Queue must have max_batch_size > 0. + EXPECT_FALSE( + scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok()); + queue_options.max_batch_size = 2; + queue_options.max_enqueued_batches = 1; + TF_ASSERT_OK( + scheduler->AddQueue(queue_options, queue_1_callback, &queue_1)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Task larger than max_batch_size shouldn't schedule. + EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok()); + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + env.AdvanceByMicroseconds(1); + + // Task larger than max_batch_size shouldn't schedule. + EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok()); + TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); + TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); + env.AdvanceByMicroseconds(1); + // Exceeds max_enqueued_batches, shouldn't schedule. + EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok()); + + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + // Exceeds max_enqueued_batches, shouldn't schedule. + EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok()); + TF_ASSERT_OK(ScheduleTask(4, queue_0.get())); + + // Batches should be processed in order from oldest to newest. + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 10); + EXPECT_EQ(queue_1_tasks, 0); + + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 10); + EXPECT_EQ(queue_1_tasks, 2); + + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 19); + EXPECT_EQ(queue_1_tasks, 2); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + double feedback = 0; + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.min_scheduling_period_micros = 200; + options.max_scheduling_period_micros = 2000; + options.env = &env; + options.scheduling_period_feedback = [&feedback] { return feedback; }; + options.feedback_smoothing_batches = 1; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 6 batches. + for (int i = 0; i < 6; i++) { + TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); + env.AdvanceByMicroseconds(1); + } + feedback = -500; + env.AdvanceByMicroseconds(994); + env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec. + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(500); + env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. + EXPECT_EQ(scheduled_items, 901); + feedback = 0; + env.AdvanceByMicroseconds(250); + env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. + EXPECT_EQ(scheduled_items, 902); + feedback = 10000; // large feedback should hit max_scheduling_period. + env.AdvanceByMicroseconds(250); + env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec. + EXPECT_EQ(scheduled_items, 903); + feedback = -10000; // large feedback should hit min_scheduling_period. + env.AdvanceByMicroseconds(1999); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 903); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec. + EXPECT_EQ(scheduled_items, 904); + env.AdvanceByMicroseconds(200); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 905); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + double feedback = 0; + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + options.scheduling_period_feedback = [&feedback] { return feedback; }; + options.feedback_smoothing_batches = 3; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 4 batches. + for (int i = 0; i < 4; i++) { + TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); + env.AdvanceByMicroseconds(1); + } + feedback = -300; + env.AdvanceByMicroseconds(996); + env.BlockUntilThreadsAsleep(2); + // ewma_feedback = 100, scheduling_period = 900. + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(899); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + // ewma_feedback = 167, scheduling_period = 750. + EXPECT_EQ(scheduled_items, 901); + env.AdvanceByMicroseconds(749); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 901); + feedback = 1000 / 3.; + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + // emwa_feedback = 0, scheduling_period = 750. + EXPECT_EQ(scheduled_items, 902); + env.AdvanceByMicroseconds(749); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 902); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 903); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + AdaptiveSharedBatchScheduler::QueueOptions queue_options; + queue_options.max_batch_size = 10; + queue_options.max_enqueued_batches = 10; + TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 3 tasks. + EXPECT_EQ(queue->NumEnqueuedTasks(), 0); + EXPECT_EQ(queue->SchedulingCapacity(), 100); + TF_ASSERT_OK(ScheduleTask(5, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 1); + EXPECT_EQ(queue->SchedulingCapacity(), 95); + env.AdvanceByMicroseconds(1); + TF_ASSERT_OK(ScheduleTask(6, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 2); + EXPECT_EQ(queue->SchedulingCapacity(), 84); + env.AdvanceByMicroseconds(1); + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 3); + EXPECT_EQ(queue->SchedulingCapacity(), 83); + + env.AdvanceByMicroseconds(998); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 5); + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 7); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} +} // namespace anonymous +} // namespace serving +} // namespace tensorflow diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index 7c41ad88180badd37398f5bae057dcd0006922c3..a5072f439abad3c5db79a514a7f2baff0b021b39 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -78,7 +78,7 @@ template class Batch { public: Batch() = default; - ~Batch(); // Blocks until the batch is closed. + virtual ~Batch(); // Blocks until the batch is closed. // Appends 'task' to the batch. After calling AddTask(), the newly-added task // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1). diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 324e519a6dbfac859d386576578f7989db0cc3c5..8bb742d289a0836378a9a03c90d46293cfbfe75b 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -20,8 +20,9 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:functional_ops", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", @@ -31,7 +32,6 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "@six_archive//:six", diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 726a8f692f5b6eb8392efadf136aecd890d7f5eb..66a04d42e93331de74b6f3d41f83f071115c1097 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -68,6 +68,10 @@ py_library( srcs = ["python/utils/losses.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:nn", ], ) @@ -77,12 +81,13 @@ py_test( size = "small", srcs = ["python/utils/losses_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":losses", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", "//third_party/py/numpy", ], ) @@ -94,13 +99,30 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":gen_model_ops_py", "//tensorflow/contrib/boosted_trees:batch_ops_utils_py", "//tensorflow/contrib/boosted_trees:boosted_trees_ops_py", "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", "//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/feature_column", ], ) @@ -110,16 +132,24 @@ py_test( srcs = ["python/training/functions/gbdt_batch_test.py"], srcs_version = "PY2AND3", tags = [ - "nomac", # b/63258195 "notsan", # b/62863147 ], deps = [ ":gbdt_batch", ":losses", + ":model_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", - "//third_party/py/numpy", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", ], ) @@ -130,16 +160,11 @@ py_test( size = "small", srcs = ["python/kernel_tests/model_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":model_ops_py", ":prediction_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", @@ -155,9 +180,6 @@ py_test( size = "small", srcs = ["python/kernel_tests/prediction_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":model_ops_py", ":prediction_ops_py", @@ -175,12 +197,12 @@ py_test( size = "small", srcs = ["python/kernel_tests/quantile_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":quantile_ops_py", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", @@ -212,9 +234,6 @@ py_test( size = "small", srcs = ["python/kernel_tests/stats_accumulator_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":stats_accumulator_ops_py", "//tensorflow/python:framework_ops", @@ -229,11 +248,7 @@ py_test( size = "small", srcs = ["python/kernel_tests/training_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ - ":boosted_trees_ops_loader", ":model_ops_py", ":training_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", @@ -268,9 +283,8 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", + "//tensorflow/python:errors", + "//tensorflow/python:platform", ], ) @@ -309,21 +323,17 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_model_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", + "//tensorflow/python:training", ], ) tf_kernel_library( name = "model_ops_kernels", - srcs = [ - "kernels/model_ops.cc", - ], + srcs = ["kernels/model_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", @@ -387,23 +397,17 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_split_handler_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", ], ) tf_kernel_library( name = "split_handler_ops_kernels", - srcs = [ - "kernels/split_handler_ops.cc", - ], + srcs = ["kernels/split_handler_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -435,25 +439,21 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_training_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", ], ) tf_kernel_library( name = "training_ops_kernels", - srcs = [ - "kernels/training_ops.cc", - ], + srcs = ["kernels/training_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", + "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", + "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", - "//tensorflow/core:framework", + "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -490,9 +490,7 @@ tf_custom_op_py_library( tf_kernel_library( name = "prediction_ops_kernels", - srcs = [ - "kernels/prediction_ops.cc", - ], + srcs = ["kernels/prediction_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:example_partitioner", "//tensorflow/contrib/boosted_trees/lib:models", @@ -500,7 +498,6 @@ tf_kernel_library( "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", ], @@ -532,26 +529,22 @@ tf_custom_op_py_library( ":batch_ops_utils_py", ":boosted_trees_ops_loader", ":gen_quantile_ops_py_wrap", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", ], ) tf_kernel_library( name = "quantile_ops_kernels", - srcs = [ - "kernels/quantile_ops.cc", - ], + srcs = ["kernels/quantile_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -581,8 +574,6 @@ tf_custom_op_py_library( ":batch_ops_utils_py", ":boosted_trees_ops_loader", ":gen_stats_accumulator_ops_py_wrap", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", "//tensorflow/python:training", @@ -591,13 +582,10 @@ tf_custom_op_py_library( tf_kernel_library( name = "stats_accumulator_ops_kernels", - srcs = [ - "kernels/stats_accumulator_ops.cc", - ], + srcs = ["kernels/stats_accumulator_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/resources:stamped_resource", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", ], alwayslink = 1, @@ -609,7 +597,12 @@ py_library( name = "boosted_trees_pip", deps = [ ":init_py", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy", "//tensorflow/contrib/boosted_trees/estimator_batch:init_py", + "//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks", + "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", + "//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", diff --git a/tensorflow/contrib/boosted_trees/README.md b/tensorflow/contrib/boosted_trees/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7d30032e539fb16e27f48ea101094fa4d3e9171d --- /dev/null +++ b/tensorflow/contrib/boosted_trees/README.md @@ -0,0 +1,11 @@ +# TF Boosted Trees (TFBT) + +TF Boosted trees is an implementation of a gradient boosting algorithm with +trees used as weak learners. + +## Examples +Folder "examples" demonstrates how TFBT estimators can be used for various +problems. Namely, it contains: +* binary_mnist.py - an example on how to use TFBT for binary classification. +* mnist.py - a multiclass example. +* boston.py - a regression example. \ No newline at end of file diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index f9e186788f6832b292a690d8d7b04e2f4edd584e..7792c7127c0285dc2eb5b213da054674f6a81d64 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -27,13 +27,6 @@ py_library( "__init__.py", ], srcs_version = "PY2AND3", - deps = [ - "custom_export_strategy", - ":custom_loss_head", - ":estimator", - ":model", - ":trainer_hooks", - ], ) py_library( @@ -41,7 +34,12 @@ py_library( srcs = ["model.py"], srcs_version = "PY2AND3", deps = [ + ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees:model_ops_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", ], ) @@ -51,6 +49,10 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", ], ) @@ -61,6 +63,15 @@ py_test( srcs_version = "PY2AND3", deps = [ ":trainer_hooks", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) @@ -69,6 +80,10 @@ py_library( srcs = ["custom_loss_head.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", ], ) @@ -82,6 +97,11 @@ py_library( "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", "//tensorflow/contrib/learn", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", ], ) @@ -92,8 +112,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":custom_export_strategy", - "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py", - "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", ], ) @@ -103,6 +124,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":model", - ":trainer_hooks", + "//tensorflow/contrib/boosted_trees:losses", + "//tensorflow/contrib/learn", + "//tensorflow/python:math_ops", ], ) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 7773125c16772fe37369b11532c7f42df3ce166f..ef8dee91b6cc05c4c3dd5eb3c81de4fb65b473e3 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -96,7 +96,8 @@ def make_custom_export_strategy(name, def convert_to_universal_format(dtec, sorted_feature_names, num_dense, num_sparse_float, - num_sparse_int): + num_sparse_int, + feature_name_to_proto=None): """Convert GTFlow trees to universal format.""" del num_sparse_int # unused. model_and_features = generic_tree_model_pb2.ModelAndFeatures() @@ -104,7 +105,11 @@ def convert_to_universal_format(dtec, sorted_feature_names, # feature is processed before it's fed to the model (e.g. bucketing # information). As of now, this serves as a list of features the model uses. for feature_name in sorted_feature_names: - model_and_features.features[feature_name].SetInParent() + if not feature_name_to_proto: + model_and_features.features[feature_name].SetInParent() + else: + model_and_features.features[feature_name].CopyFrom( + feature_name_to_proto[feature_name]) model = model_and_features.model model.ensemble.summation_combination_technique.SetInParent() for tree_idx in range(len(dtec.trees)): @@ -144,6 +149,8 @@ def convert_to_universal_format(dtec, sorted_feature_names, split = gtflow_node.sparse_float_binary_split_default_left.split node.default_direction = ( generic_tree_model_pb2.BinaryNode.LEFT) + # TODO(nponomareva): adjust this id assignement when we allow multi- + # column sparse tensors. feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test inequality_test.feature_id.id.value = sorted_feature_names[feature_id] @@ -154,6 +161,8 @@ def convert_to_universal_format(dtec, sorted_feature_names, split = gtflow_node.sparse_float_binary_split_default_right.split node.default_direction = ( generic_tree_model_pb2.BinaryNode.RIGHT) + # TODO(nponomareva): adjust this id assignement when we allow multi- + # column sparse tensors. feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test inequality_test.feature_id.id.value = sorted_feature_names[feature_id] diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index f8028acbdb0be44b7fd81b96b04b6e24d9060aa6..01752416b347dd0a5e646283b6b5572592df4690 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -19,8 +19,10 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.boosted_trees.estimator_batch import model +from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.ops import math_ops class GradientBoostedDecisionTreeClassifier(estimator.Estimator): @@ -65,10 +67,21 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): Raises: ValueError: If learner_config is not valid. """ + if n_classes > 2: + # For multi-class classification, use our loss implementation that + # supports second order derivative. + def loss_fn(labels, logits, weights=None): + result = losses.per_example_maxent_loss( + labels=labels, logits=logits, weights=weights, + num_classes=n_classes) + return math_ops.reduce_mean(result[0]) + else: + loss_fn = None head = head_lib.multi_class_head( n_classes=n_classes, weight_column_name=weight_column_name, - enable_centered_bias=False) + enable_centered_bias=False, + loss_fn=loss_fn) if learner_config.num_classes == 0: learner_config.num_classes = n_classes elif learner_config.num_classes != n_classes: diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 8cda5c8f2b14f2ec3cfe3702e38b81803dd075f7..c6455a7ea3d18eb358edee034cee58b2bed21024 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -93,7 +93,7 @@ def model_builder(features, labels, mode, params, config): learner_config=learner_config, feature_columns=feature_columns, logits_dimension=head.logits_dimension, - features=features) + features=training_features) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] diff --git a/tensorflow/contrib/boosted_trees/examples/binary_mnist.py b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..47ee3d816f41e44f3a2458cf537d4f7dccf7b614 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py @@ -0,0 +1,169 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates multiclass MNIST TF Boosted trees example. + + This example demonstrates how to run experiments with TF Boosted Trees on + a binary dataset. We use digits 4 and 9 from the original MNIST dataset. + + Example Usage: + python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \ + --output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \ + --batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \ + --num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1 + + When training is done, accuracy on eval data is reported. Point tensorboard + to the directory for the run to see how the training progresses: + + tensorboard --logdir=/tmp/binary_mnist + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.learn import learn_runner + + +def get_input_fn(data, + batch_size, + capacity=10000, + min_after_dequeue=3000): + """Input function over MNIST data.""" + # Keep only 4 and 9 digits. + ids = np.where((data.labels == 4) | (data.labels == 9)) + images = data.images[ids] + labels = data.labels[ids] + # Make digit 4 label 1, 9 is 0. + labels = labels == 4 + + def _input_fn(): + """Prepare features and labels.""" + images_batch, labels_batch = tf.train.shuffle_batch( + tensors=[images, + labels.astype(np.int32)], + batch_size=batch_size, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + enqueue_many=True, + num_threads=4) + features_map = {"images": images_batch} + return features_map, labels_batch + + return _input_fn + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer + learner_config.constraints.max_tree_depth = FLAGS.depth + + growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + learner_config.growing_mode = growing_mode + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + # Create a TF Boosted trees estimator that can take in custom loss. + estimator = GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + examples_per_layer=FLAGS.examples_per_layer, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + data = tf.contrib.learn.datasets.mnist.load_mnist() + train_input_fn = get_input_fn(data.train, FLAGS.batch_size) + eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size) + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--eval_batch_size", + type=int, + default=1000, + help="Size of the batch for eval.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--examples_per_layer", + type=int, + default=1000, + help="Number of examples to accumulate stats for per layer.") + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0a3c4912b82aba88e2f8f1b97a227c894ee2ae --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates a regression on Boston housing data. + + This example demonstrates how to run experiments with TF Boosted Trees on + a regression dataset. We split all the data into 20% test and 80% train, + and are using l2 loss and l2 regularization. + + Example Usage: + + python tensorflow/contrib/boosted_trees/examples/boston.py \ + --batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \ + --num_eval_steps=1 --num_trees=500 --l2=4 \ + --vmodule=training_ops=1 + + When training is done, mean squared error on eval data is reported. + Point tensorboard to the directory for the run to see how the training + progresses: + + tensorboard --logdir=/tmp/boston + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn import learn_runner + +_BOSTON_NUM_FEATURES = 13 + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir, feature_cols): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.regularization.l1 = 0.0 + # Set the regularization per instance in such a way that + # regularization for the full training data is equal to l2 flag. + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.batch_size + learner_config.constraints.max_tree_depth = FLAGS.depth + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + # Create a TF Boosted trees regression estimator. + estimator = GradientBoostedDecisionTreeRegressor( + learner_config=learner_config, + # For the WHOLE_TREE strategy, set the examples_per_layer to be equal to + # batch size. + examples_per_layer=FLAGS.batch_size, + feature_columns=feature_cols, + label_dimension=1, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + (x_train, y_train), (x_test, + y_test) = tf.keras.datasets.boston_housing.load_data() + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_train}, + y=y_train, + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + + eval_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) + + feature_columns = [ + feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) + ] + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir, feature_columns), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/examples/mnist.py b/tensorflow/contrib/boosted_trees/examples/mnist.py index 7e34d2f2d36e1022845fa63ba44b2df7d8a2cf55..817c6eb3e1a79b38746418db9e5015e65ee70a50 100644 --- a/tensorflow/contrib/boosted_trees/examples/mnist.py +++ b/tensorflow/contrib/boosted_trees/examples/mnist.py @@ -22,7 +22,7 @@ r"""Demonstrates multiclass MNIST TF Boosted trees example. python tensorflow/contrib/boosted_trees/examples/mnist.py \ --output_dir="/tmp/mnist" --depth=4 --learning_rate=0.3 --batch_size=60000 \ --examples_per_layer=60000 --eval_batch_size=10000 --num_eval_steps=1 \ - --num_trees=10 --l2=1 --vmodule=training_ops=1 \ + --num_trees=10 --l2=1 --vmodule=training_ops=1 When training is done, accuracy on eval data is reported. Point tensorboard to the directory for the run to see how the training progresses: @@ -35,18 +35,13 @@ from __future__ import division from __future__ import print_function import argparse -import functools import sys import numpy as np import tensorflow as tf -from tensorflow.contrib import metrics as metrics_lib -from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head -from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeEstimator +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier from tensorflow.contrib.boosted_trees.proto import learner_pb2 -from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn import learn_runner -from tensorflow.python.ops import math_ops def get_input_fn(dataset_split, @@ -88,36 +83,13 @@ def _get_tfbt(output_dir): learner_config.growing_mode = growing_mode run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) - # Use Cross Entropy loss (the impl in losses is twice differentiable). - loss_fn = functools.partial( - losses.per_example_maxent_loss, num_classes=num_classes) - logit_dim = num_classes learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) - # Since we use custom head, we need to tell how accuracy is calculated. - def _multiclass_metrics(predictions, labels, weights): - """Prepares eval metrics for multiclass eval.""" - metrics = dict() - logits = predictions["scores"] - classes = math_ops.argmax(logits, 1) - metrics["accuracy"] = metrics_lib.streaming_accuracy( - classes, labels, weights) - return metrics - - metrics_fn = _multiclass_metrics - # Use custom loss head so we can provide our loss (cross entropy for - # multiclass). - head = custom_loss_head.CustomLossHead( - loss_fn=loss_fn, - link_fn=tf.identity, - logit_dimension=logit_dim, - metrics_fn=metrics_fn) - # Create a TF Boosted trees estimator that can take in custom loss. - estimator = GradientBoostedDecisionTreeEstimator( + estimator = GradientBoostedDecisionTreeClassifier( learner_config=learner_config, - head=head, + n_classes=num_classes, examples_per_layer=FLAGS.examples_per_layer, model_dir=output_dir, num_trees=FLAGS.num_trees, @@ -129,8 +101,8 @@ def _get_tfbt(output_dir): def _make_experiment_fn(output_dir): """Creates experiment for gradient boosted decision trees.""" data = tf.contrib.learn.datasets.mnist.load_mnist() - train_input_fn = get_input_fn(data.train, batch_size=256) - eval_input_fn = get_input_fn(data.validation, batch_size=5000) + train_input_fn = get_input_fn(data.train, FLAGS.batch_size) + eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size) return tf.contrib.learn.Experiment( estimator=_get_tfbt(output_dir), diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index d63be3d0415ab0d5f65f72073e34b8eb66ab747a..4b5d5ba0de6c3995ee2da7a44ab0ba099cbf1b35 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -15,7 +15,6 @@ #include #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" -#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index d4d405c3a9a894e333fdf2278625d510cdeef1fe..107ff0d295bee530c1711a97849fbd3c6cdb2f00 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -81,6 +81,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "example_test", + size = "small", + srcs = ["utils/example_test.cc"], + deps = [ + ":utils", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "batch_features_test", size = "small", @@ -132,7 +144,6 @@ tf_cc_test( ":random_tree_gen", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -149,7 +160,6 @@ cc_library( deps = [ ":utils", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:testlib", ], @@ -197,7 +207,6 @@ tf_cc_test( srcs = ["quantiles/weighted_quantiles_buffer_test.cc"], deps = [ ":weighted_quantiles", - "//tensorflow/core", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -210,7 +219,6 @@ tf_cc_test( srcs = ["quantiles/weighted_quantiles_summary_test.cc"], deps = [ ":weighted_quantiles", - "//tensorflow/core", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -262,6 +270,8 @@ py_library( srcs = ["learner/batch/base_split_handler.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/boosted_trees:batch_ops_utils_py", + "//tensorflow/python:control_flow_ops", ], ) @@ -271,9 +281,13 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base_split_handler", - "//tensorflow/contrib/boosted_trees:quantile_ops_py", "//tensorflow/contrib/boosted_trees:split_handler_ops_py", "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", ], ) @@ -285,7 +299,15 @@ py_test( ":categorical_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", ], ) @@ -298,7 +320,14 @@ py_library( "//tensorflow/contrib/boosted_trees:quantile_ops_py", "//tensorflow/contrib/boosted_trees:split_handler_ops_py", "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py", - "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", ], ) @@ -310,7 +339,15 @@ py_test( ":ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", ], ) diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc index bd70586393eb062a46b6e242c6094ef0605804e2..f8750e7191673274772fc869c198dd5fbbefbc49 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -50,10 +50,15 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, current_node.sparse_float_binary_split_default_left().split(); auto sparse_feature = example.sparse_float_features[split.feature_column()]; - node_id = !sparse_feature.has_value() || - sparse_feature.get_value() <= split.threshold() - ? split.left_id() - : split.right_id(); + // Feature id for the split when multivalent sparse float column, or 0 + // by default. + const int32 feature_id = split.feature_id(); + + node_id = + !sparse_feature[feature_id].has_value() || + sparse_feature[feature_id].get_value() <= split.threshold() + ? split.left_id() + : split.right_id(); break; } case TreeNode::kSparseFloatBinarySplitDefaultRight: { @@ -61,10 +66,14 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, current_node.sparse_float_binary_split_default_right().split(); auto sparse_feature = example.sparse_float_features[split.feature_column()]; - node_id = sparse_feature.has_value() && - sparse_feature.get_value() <= split.threshold() - ? split.left_id() - : split.right_id(); + // Feature id for the split when multivalent sparse float column, or 0 + // by default. + const int32 feature_id = split.feature_id(); + node_id = + sparse_feature[feature_id].has_value() && + sparse_feature[feature_id].get_value() <= split.threshold() + ? split.left_id() + : split.right_id(); break; } case TreeNode::kCategoricalIdBinarySplit: { diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc index c55d09807eaf3a9c9db1cfbbfdfc66aec8f25155..93924d429c19aef51b6f1d85655de3798a76e3e0 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc @@ -27,13 +27,14 @@ class DecisionTreeTest : public ::testing::Test { protected: DecisionTreeTest() : batch_features_(2) { // Create a batch of two examples having one dense float, two sparse float - // and one sparse int features. + // and one sparse int features, and one sparse multi-column float feature + // (SparseFM). // The first example is missing the second sparse feature column and the // second example is missing the first sparse feature column. // This looks like the following: - // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | - // 0 | 7 | -3 | | 3 | - // 1 | -2 | | 4 | | + // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseFM (3 cols) + // 0 | 7 | -3 | | 3 | 3.0 | | 1.0 + // 1 | -2 | | 4 | | 1.5 |3.5| auto dense_float_matrix = test::AsTensor({7.0f, -2.0f}, {2, 1}); auto sparse_float_indices1 = test::AsTensor({0, 0}, {1, 2}); auto sparse_float_values1 = test::AsTensor({-3.0f}); @@ -44,11 +45,21 @@ class DecisionTreeTest : public ::testing::Test { auto sparse_int_indices1 = test::AsTensor({0, 0}, {1, 2}); auto sparse_int_values1 = test::AsTensor({3}); auto sparse_int_shape1 = test::AsTensor({2, 1}); + + // Multivalent sparse feature. + auto multi_sparse_float_indices = + test::AsTensor({0, 0, 0, 2, 1, 0, 1, 1}, {4, 2}); + auto multi_sparse_float_values = + test::AsTensor({3.0f, 1.0f, 1.5f, 3.5f}); + auto multi_sparse_float_shape = test::AsTensor({2, 3}); + TF_EXPECT_OK(batch_features_.Initialize( - {dense_float_matrix}, {sparse_float_indices1, sparse_float_indices2}, - {sparse_float_values1, sparse_float_values2}, - {sparse_float_shape1, sparse_float_shape2}, {sparse_int_indices1}, - {sparse_int_values1}, {sparse_int_shape1})); + {dense_float_matrix}, + {sparse_float_indices1, sparse_float_indices2, + multi_sparse_float_indices}, + {sparse_float_values1, sparse_float_values2, multi_sparse_float_values}, + {sparse_float_shape1, sparse_float_shape2, multi_sparse_float_shape}, + {sparse_int_indices1}, {sparse_int_values1}, {sparse_int_shape1})); } template @@ -121,44 +132,90 @@ TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) { } TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { - // Test first sparse feature which is missing for the second example. - DecisionTreeConfig tree_config1; - auto* split_node1 = tree_config1.add_nodes() - ->mutable_sparse_float_binary_split_default_left() - ->mutable_split(); - split_node1->set_feature_column(0); - split_node1->set_threshold(-20.0f); - split_node1->set_left_id(1); - split_node1->set_right_id(2); - tree_config1.add_nodes()->mutable_leaf(); - tree_config1.add_nodes()->mutable_leaf(); auto example_iterable = batch_features_.examples_iterable(0, 2); - - // Expect right child to be picked as !(-3 <= -20). - auto example_it = example_iterable.begin(); - EXPECT_EQ(2, DecisionTree::Traverse(tree_config1, 0, *example_it)); - - // Expect left child to be picked as default direction. - EXPECT_EQ(1, DecisionTree::Traverse(tree_config1, 0, *++example_it)); - + // Split on SparseF1. + // Test first sparse feature which is missing for the second example. + { + DecisionTreeConfig tree_config; + auto* split_node = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_left() + ->mutable_split(); + split_node->set_feature_column(0); + split_node->set_threshold(-20.0f); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + + // Expect right child to be picked as !(-3 <= -20). + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect left child to be picked as default direction. + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + } + // Split on SparseF2. // Test second sparse feature which is missing for the first example. - DecisionTreeConfig tree_config2; - auto* split_node2 = tree_config2.add_nodes() - ->mutable_sparse_float_binary_split_default_right() - ->mutable_split(); - split_node2->set_feature_column(1); - split_node2->set_threshold(4.0f); - split_node2->set_left_id(1); - split_node2->set_right_id(2); - tree_config2.add_nodes()->mutable_leaf(); - tree_config2.add_nodes()->mutable_leaf(); - - // Expect right child to be picked as default direction. - example_it = example_iterable.begin(); - EXPECT_EQ(2, DecisionTree::Traverse(tree_config2, 0, *example_it)); - - // Expect left child to be picked as (4 <= 4). - EXPECT_EQ(1, DecisionTree::Traverse(tree_config2, 0, *++example_it)); + { + DecisionTreeConfig tree_config; + auto* split_node = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_right() + ->mutable_split(); + split_node->set_feature_column(1); + split_node->set_threshold(4.0f); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + + // Expect right child to be picked as default direction. + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect left child to be picked as (4 <= 4). + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + } + // Split on SparseFM. + // Test second sparse feature which is missing for the first example. + { + DecisionTreeConfig tree_config; + auto* split_node = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_right() + ->mutable_split(); + split_node->set_feature_column(2); + + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + + // Split on first column + split_node->set_feature_id(0); + split_node->set_threshold(2.0f); + + // Both instances have this feature value. + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + + // Split on second column + split_node->set_feature_id(1); + split_node->set_threshold(5.0f); + + // First instance does not have it (default right), second does have it. + example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + + // Split on third column + split_node->set_feature_id(2); + split_node->set_threshold(3.0f); + example_it = example_iterable.begin(); + + // First instance has it, second does not (default right). + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it)); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it)); + } } TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index 12b377dda7852bb5a580c4ccc1d239709ef9bfc0..cf4f9a097a3368465fd4d9afb981bbaa68b4df49 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -94,10 +94,6 @@ Status BatchFeatures::Initialize( shape_flat(0) == batch_size_, errors::InvalidArgument( "Sparse float feature shape incompatible with batch size.")); - TF_CHECK_AND_RETURN_IF_ERROR( - shape_flat(1) <= 1, - errors::InvalidArgument( - "Sparse float features may not be multi-valent.")); auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)}); auto order_dims = sparse::SparseTensor::VarDimArray({0, 1}); sparse_float_feature_columns_.emplace_back(sparse_float_feature_indices, diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index bb11dc9a0778c062c68433c001e7935388e0f45c..7a550d6f7328765d8815a947885e47fa0b0a8f8b 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -45,6 +45,22 @@ class BatchFeatures { std::vector sparse_int_feature_values_list, std::vector sparse_int_feature_shapes_list); + Status GetFeatureColumnSizes(int64* const num_dense_float_features, + int64* const num_sparse_float_features, + int64* const num_sparse_int_features) const { + QCHECK_NE(num_dense_float_features, nullptr); + QCHECK_NE(num_sparse_float_features, nullptr); + QCHECK_NE(num_sparse_int_features, nullptr); + *num_dense_float_features = dense_float_feature_columns_.size(); + *num_sparse_float_features = sparse_float_feature_columns_.size(); + *num_sparse_int_features = sparse_int_feature_columns_.size(); + if (*num_dense_float_features == 0 && *num_sparse_float_features == 0 && + *num_sparse_int_features == 0) { + return errors::FailedPrecondition("Not intialized yet."); + } + return Status::OK(); + } + // Creates an example iterable for the requested slice. ExamplesIterable examples_iterable(int64 example_start, int64 example_end) const { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 7f523d527adeb60d179bfce4bc5ef32e75e34ca2..9de3e32b097a151b3bd6f5c30df2db0938b65e9c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -129,19 +129,6 @@ TEST_F(BatchFeaturesTest, SparseFloatFeatures_IncompatibleShape) { {sparse_float_feature_shape}, {}, {}, {})); } -TEST_F(BatchFeaturesTest, SparseFloatFeatures_Multivalent) { - BatchFeatures batch_features(2); - auto sparse_float_feature_indices = AsTensor({0, 0, 1, 0}, {2, 2}); - auto sparse_float_feature_values = AsTensor({3.0f, 7.0f}); - auto sparse_float_feature_shape = AsTensor({2, 2}); - auto expected_error = - InvalidArgument("Sparse float features may not be multi-valent."); - EXPECT_EQ(expected_error, batch_features.Initialize( - {}, {sparse_float_feature_indices}, - {sparse_float_feature_values}, - {sparse_float_feature_shape}, {}, {}, {})); -} - TEST_F(BatchFeaturesTest, SparseIntFeatures_WrongShapeIndices) { BatchFeatures batch_features(2); auto sparse_int_feature_indices = AsTensor({0, 0, 1, 0}); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h index 4681eb06aa2c11a33db4d6e8ff3f0148ffd82917..9514416660c42d71a15d99e791e804e4bb6fff60 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/example.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h @@ -16,6 +16,8 @@ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#include +#include #include #include #include "tensorflow/contrib/boosted_trees/lib/utils/optional_value.h" @@ -24,6 +26,56 @@ namespace tensorflow { namespace boosted_trees { namespace utils { +// A matrix that given feature column id and feature value id will return +// either a value or an optional. First index indicates feature column, second +// index - the index of the value within this column - for single valued, it +// will be 0. +// Allows double-subscript access [][]. +template +class SparseMatrix { + typedef std::vector> SparseMap; + + class Proxy { + public: + Proxy(const int32 feature_column_idx, const SparseMap& values) + : feature_column_idx_(feature_column_idx), values_(values) {} + + OptionalValue operator[](int feature_idx) const { + auto value_iter = std::find_if( + values_.begin(), values_.end(), + [this, &feature_idx](const std::tuple& element) { + return std::get<0>(element) == feature_column_idx_ && + std::get<1>(element) == feature_idx; + }); + + if (value_iter == values_.end()) { + return OptionalValue(); + } + // There is this feature column and feature id. + return OptionalValue(std::get<2>(*value_iter)); + } + + private: + int32 feature_column_idx_; + const SparseMap& values_; + }; + + public: + void addElement(const int32 feature_column_idx, const int32 feature_idx, + const T value) { + values_.emplace_back(feature_column_idx, feature_idx, value); + } + + void clear() { values_.clear(); } + + Proxy operator[](int feature_column_idx) const { + return Proxy(feature_column_idx, values_); + } + + private: + SparseMap values_; +}; + // Holds data for one example and enables lookup by feature column. struct Example { // Default constructor creates an empty example. @@ -35,7 +87,9 @@ struct Example { // Dense and sparse float features indexed by feature column. // TODO(salehay): figure out a design to support multivalent float features. std::vector dense_float_features; - std::vector> sparse_float_features; + // Sparse float features are allowed to be multivalent and thus can be + // represented as a sparse matrix. + SparseMatrix sparse_float_features; // Sparse integer features indexed by feature column. // Note that all integer features are assumed to be categorical, i.e. will diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/example_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f78fd25022e3e9958c0f521927e7d83eddd8f97f --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/utils/example_test.cc @@ -0,0 +1,81 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/utils/example.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace boosted_trees { +namespace utils { +namespace { + +class ExampleTest : public ::testing::Test {}; + +TEST_F(ExampleTest, TestSparseMatrix) { + // Create the following matrix: + // row id | | 0.4 | 0.3 + // 0 | 1 | | 2 + // 1 | 3 | 1 | 5 + // 2 | | | -4 + // 3 | | | + SparseMatrix matrix; + matrix.addElement(0, 1, 0.4f); + matrix.addElement(0, 2, 0.3f); + matrix.addElement(1, 0, 1.f); + matrix.addElement(1, 2, 2.f); + matrix.addElement(2, 0, 3.f); + matrix.addElement(2, 1, 1.f); + matrix.addElement(2, 2, 5.f); + matrix.addElement(3, 2, -4.f); + + // Row 0. + EXPECT_FALSE(matrix[0][0].has_value()); + EXPECT_TRUE(matrix[0][1].has_value()); + EXPECT_EQ(0.4f, matrix[0][1].get_value()); + EXPECT_TRUE(matrix[0][2].has_value()); + EXPECT_EQ(0.3f, matrix[0][2].get_value()); + + // Row 1. + EXPECT_TRUE(matrix[1][0].has_value()); + EXPECT_EQ(1.f, matrix[1][0].get_value()); + EXPECT_FALSE(matrix[1][1].has_value()); + EXPECT_TRUE(matrix[1][2].has_value()); + EXPECT_EQ(2.f, matrix[1][2].get_value()); + + // Row 2. + EXPECT_TRUE(matrix[2][0].has_value()); + EXPECT_EQ(3.f, matrix[2][0].get_value()); + EXPECT_TRUE(matrix[2][1].has_value()); + EXPECT_EQ(1.f, matrix[2][1].get_value()); + EXPECT_TRUE(matrix[2][2].has_value()); + EXPECT_EQ(5.f, matrix[2][2].get_value()); + + // Row 3. + EXPECT_FALSE(matrix[3][0].has_value()); + EXPECT_FALSE(matrix[3][1].has_value()); + EXPECT_TRUE(matrix[3][2].has_value()); + EXPECT_EQ(-4.f, matrix[3][2].get_value()); + + // Row 4. + EXPECT_FALSE(matrix[4][0].has_value()); + EXPECT_FALSE(matrix[4][1].has_value()); + EXPECT_FALSE(matrix[4][2].has_value()); +} + +} // namespace +} // namespace utils +} // namespace boosted_trees +} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc index c73dc8e15d42f2c80078cf628b5cd5773f5860ff..3b287b1dcfe9fd9be3a16e3ba8d58d13b4efbb4f 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc @@ -73,8 +73,6 @@ Iterator::Iterator(ExamplesIterable* iter, int64 example_idx) // Pre-size example features. example_.dense_float_features.resize( iter_->dense_float_column_values_.size()); - example_.sparse_float_features.resize( - iter_->sparse_float_column_values_.size()); example_.sparse_int_features.resize(iter_->sparse_int_column_values_.size()); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h index 67efb82a227a3d7e92cdf5c8307a6f04c45fb617..72b7486872ecef7adb6bc23e37b4e9c279f4b7f5 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h @@ -87,19 +87,34 @@ class ExamplesIterable { // Get sparse float values per column. auto& sparse_float_features = example_.sparse_float_features; + sparse_float_features.clear(); + // Iterate through each sparse float feature column. for (size_t sparse_float_idx = 0; - sparse_float_idx < sparse_float_features.size(); + sparse_float_idx < iter_->sparse_float_column_iterables_.size(); ++sparse_float_idx) { + // Get range for values tensor. const auto& row_range = (*sparse_float_column_iterators_[sparse_float_idx]); DCHECK_EQ(example_idx_, row_range.example_idx); + // If the example has this feature column. if (row_range.start < row_range.end) { - DCHECK_EQ(1, row_range.end - row_range.start); - sparse_float_features[sparse_float_idx] = OptionalValue( - iter_->sparse_float_column_values_[sparse_float_idx]( - row_range.start)); - } else { - sparse_float_features[sparse_float_idx] = OptionalValue(); + // Retrieve original indices tensor. + const TTypes::ConstMatrix& indices = + iter_->sparse_float_column_iterables_[sparse_float_idx] + .sparse_indices(); + + // For each value. + for (int64 row_idx = row_range.start; row_idx < row_range.end; + ++row_idx) { + // Get the feature id for the feature column and the value. + const int32 feature_id = indices(row_idx, 1); + DCHECK_EQ(example_idx_, indices(row_idx, 0)); + + // Save the value to our sparse matrix. + sparse_float_features.addElement( + sparse_float_idx, feature_id, + iter_->sparse_float_column_values_[sparse_float_idx](row_idx)); + } } } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc index d93bcc8aa67102fcdacf130d90769514ce6c8170..05c166edc61b02d87f83f4d1b35f9825a93209ec 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc @@ -26,17 +26,17 @@ class ExamplesIterableTest : public ::testing::Test {}; TEST_F(ExamplesIterableTest, Iterate) { // Create a batch of 8 examples having one dense float, two sparse float and - // two sparse int features. + // two sparse int features. Second sparse float feature is multivalent. // The data looks like the following: // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseI2 | - // 0 | 7 | -3 | | 1, 8 | | - // 1 | -2 | | 4 | 0 | 7 | - // 2 | 8 | 0 | | | 13 | - // 3 | 1 | 5 | 7 | 2, 0 | 4 | - // 4 | 0 | 0 | | | 0 | - // 5 | -4 | | 9 | | | - // 6 | 7 | | | | | - // 7 | -2 | | -4 | 5 | | + // 0 | 7 | -3 | | 1 | 1, 8 | | + // 1 | -2 | | 4 | | 0 | 7 | + // 2 | 8 | 0 | | 3 | | 13 | + // 3 | 1 | 5 | 7 | | 2, 0 | 4 | + // 4 | 0 | 0 | | 4.3 | | 0 | + // 5 | -4 | | 9 | 0.8 | | | + // 6 | 7 | | | | | | + // 7 | -2 | | -4 | | 5 | | auto dense_float_tensor = test::AsTensor( {7.0f, -2.0f, 8.0f, 1.0f, 0.0f, -4.0f, 7.0f, -2.0f}, {8, 1}); auto sparse_float_indices1 = @@ -45,10 +45,11 @@ TEST_F(ExamplesIterableTest, Iterate) { auto sparse_float_shape1 = TensorShape({8, 1}); sparse::SparseTensor sparse_float_tensor1( sparse_float_indices1, sparse_float_values1, sparse_float_shape1); - auto sparse_float_indices2 = - test::AsTensor({1, 0, 3, 0, 5, 0, 7, 0}, {4, 2}); - auto sparse_float_values2 = test::AsTensor({4.0f, 7.0f, 9.0f, -4.0f}); - auto sparse_float_shape2 = TensorShape({8, 1}); + auto sparse_float_indices2 = test::AsTensor( + {0, 1, 1, 0, 2, 1, 3, 0, 4, 1, 5, 0, 5, 1, 7, 0}, {8, 2}); + auto sparse_float_values2 = + test::AsTensor({1.f, 4.0f, 3.f, 7.0f, 4.3f, 9.0f, 0.8f, -4.0f}); + auto sparse_float_shape2 = TensorShape({8, 2}); sparse::SparseTensor sparse_float_tensor2( sparse_float_indices2, sparse_float_values2, sparse_float_shape2); auto sparse_int_indices1 = @@ -67,15 +68,19 @@ TEST_F(ExamplesIterableTest, Iterate) { auto validate_example_features = [](int64 example_idx, const Example& example) { EXPECT_EQ(1, example.dense_float_features.size()); - EXPECT_EQ(2, example.sparse_float_features.size()); switch (example_idx) { case 0: { EXPECT_EQ(0, example.example_idx); EXPECT_EQ(7.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(-3.0f, example.sparse_float_features[0].get_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(-3.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2 - multivalent. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(1.0f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(2, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(1)); EXPECT_EQ(1, example.sparse_int_features[0].count(8)); @@ -84,9 +89,13 @@ TEST_F(ExamplesIterableTest, Iterate) { case 1: { EXPECT_EQ(1, example.example_idx); EXPECT_EQ(-2.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(4.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(4.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(1, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(0)); EXPECT_EQ(1, example.sparse_int_features[1].size()); @@ -95,9 +104,14 @@ TEST_F(ExamplesIterableTest, Iterate) { case 2: { EXPECT_EQ(2, example.example_idx); EXPECT_EQ(8.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(0.0f, example.sparse_float_features[0].get_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(0.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(3.f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[1].size()); EXPECT_EQ(1, example.sparse_int_features[1].count(13)); @@ -105,10 +119,14 @@ TEST_F(ExamplesIterableTest, Iterate) { case 3: { EXPECT_EQ(3, example.example_idx); EXPECT_EQ(1.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(5.0f, example.sparse_float_features[0].get_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(7.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(5.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(7.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(2, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(2)); EXPECT_EQ(1, example.sparse_int_features[0].count(0)); @@ -118,9 +136,14 @@ TEST_F(ExamplesIterableTest, Iterate) { case 4: { EXPECT_EQ(4, example.example_idx); EXPECT_EQ(0.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(0.0f, example.sparse_float_features[0].get_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(0.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(4.3f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[1].size()); EXPECT_EQ(1, example.sparse_int_features[1].count(0)); @@ -128,24 +151,37 @@ TEST_F(ExamplesIterableTest, Iterate) { case 5: { EXPECT_EQ(5, example.example_idx); EXPECT_EQ(-4.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(9.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(9.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(0.8f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); } break; case 6: { EXPECT_EQ(6, example.example_idx); EXPECT_EQ(7.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); } break; case 7: { EXPECT_EQ(7, example.example_idx); EXPECT_EQ(-2.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(-4.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(-4.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(1, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(5)); } break; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h index 78a5752730cb793394c41c56ab83b084a6f76088..9664c9d1c6a0c0c8b1bbd1506944c54d2310c611 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h @@ -112,6 +112,8 @@ class SparseColumnIterable { int64 example_start() const { return example_start_; } int64 example_end() const { return example_end_; } + const TTypes::ConstMatrix& sparse_indices() const { return ix_; } + private: // Sparse indices matrix. TTypes::ConstMatrix ix_; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc index 7792bd8c66c53c0f11cff113c3e5526c6d50dbb8..0138aae3dbd3773241cb6644db625b99f9bf1372 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc @@ -34,19 +34,19 @@ TEST_F(SparseColumnIterableTest, Empty) { } TEST_F(SparseColumnIterableTest, Iterate) { - // 8 examples having 7 sparse features with the third multi-valent. + // 8 examples having 7 sparse features with the 3rd and 7th multi-valent. // This can be visualized like the following: // Instance | Sparse | - // 0 | x | + // 0 | x | // 1 | | // 2 | | // 3 | xxx | - // 4 | x | + // 4 | x | // 5 | | // 6 | | - // 7 | xx | + // 7 | x x | const auto indices = - AsTensor({0, 0, 3, 0, 3, 1, 3, 2, 4, 0, 7, 0, 7, 1}, {7, 2}); + AsTensor({0, 0, 3, 0, 3, 1, 3, 2, 4, 0, 7, 0, 7, 2}, {7, 2}); auto validate_example_range = [](const ExampleRowRange& range) { switch (range.example_idx) { diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 2e9d45efd71adef828a55e54f48d2740b8c1a12e..f14abf45a517ad7c4c6d7bb1ab88b7a1d47d6fb6 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,6 +53,9 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; + // If feature column is multivalent, this holds the index of the feature for + // the split. Defaults to 0. + int32 feature_id = 5; float threshold = 2; // Node children indexing into a contiguous diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index f8f4b43a072a91f1563b20d6ba3aef82fd4b9896..5a917ca42897a263bf9f868393453ba232745e65 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -344,7 +344,7 @@ class GradientBoostedDecisionTreeModel(object): learner_config.num_classes == 2) def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): - """Runs prediciton and returns a dictionary of the prediction results. + """Runs prediction and returns a dictionary of the prediction results. Args: ensemble_handle: ensemble resource handle. diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 4f128b230180d8e8070f63c369bc7fc2f3d24376..1e8b3ac08a74a94a0e5729e42ace91398a7b5c94 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -101,7 +101,10 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): unweighted_loss = array_ops.expand_dims(-math_ops.log(probs_for_real_class), 1) - return unweighted_loss * weights, control_flow_ops.no_op() + if weights is None: + return unweighted_loss, control_flow_ops.no_op() + else: + return unweighted_loss * weights, control_flow_ops.no_op() def per_example_squared_loss(labels, weights, predictions): diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 77e6ecb443dd3f0f7a96b7453f558d58f01c7a21..284ad5cdb9abf374650940ade7bb36663d72c0dd 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -47,6 +47,7 @@ class DecisionTreeEnsembleResource : public StampedResource { int32 num_trees() const { return decision_tree_ensemble_->trees_size(); } bool InitFromSerialized(const string& serialized, const int64 stamp_token) { + CHECK_EQ(stamp(), -1) << "Must Reset before Init."; if (ParseProtoUnlimited(decision_tree_ensemble_, serialized)) { set_stamp(stamp_token); return true; @@ -126,7 +127,7 @@ class DecisionTreeEnsembleResource : public StampedResource { // Resets the resource and frees the protos in arena. // Caller needs to hold the mutex lock while calling this. - void Reset() { + virtual void Reset() { // Reset stamp. set_stamp(-1); diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 09ec7e42c7eede97b9c7eeee329fe0649365869e..56f930a9a8d32c5c3a025163ef56c9562f17d864 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -23,7 +23,9 @@ load( filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", @@ -34,9 +36,7 @@ filegroup( tf_kernel_library( name = "bigquery_reader_ops", - srcs = [ - "bigquery_reader_ops.cc", - ], + srcs = ["bigquery_reader_ops.cc"], visibility = ["//visibility:public"], deps = [ ":bigquery_table_accessor", @@ -50,12 +50,8 @@ tf_kernel_library( cc_library( name = "bigquery_table_accessor", - srcs = [ - "bigquery_table_accessor.cc", - ], - hdrs = [ - "bigquery_table_accessor.h", - ], + srcs = ["bigquery_table_accessor.cc"], + hdrs = ["bigquery_table_accessor.h"], copts = tf_copts(), linkstatic = 1, deps = [ @@ -64,7 +60,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform/cloud:curl_http_request", "//tensorflow/core/platform/cloud:google_auth_provider", - "//tensorflow/core/platform/cloud:http_request", ], alwayslink = 1, ) @@ -88,8 +83,6 @@ tf_cc_test( tf_proto_library( name = "bigquery_table_partition_proto", - srcs = [ - "bigquery_table_partition.proto", - ], + srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index d98579d2077f0a3bc58e6466ee830e53f44f40cb..836889895567f679d9960e29ece1600d1a7a58eb 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip) -set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip) +set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/external/gif.cmake b/tensorflow/contrib/cmake/external/gif.cmake index 5cb719b8787781084335779960887613df90217d..3d53c51fffcec1602a3b5553cdf3b225e3b0ae46 100644 --- a/tensorflow/contrib/cmake/external/gif.cmake +++ b/tensorflow/contrib/cmake/external/gif.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/) -set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) +set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1) set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install) set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif) diff --git a/tensorflow/contrib/cmake/external/jpeg.cmake b/tensorflow/contrib/cmake/external/jpeg.cmake index 058f554b8f2ffc4f925012e8772c684965304833..d9a165e856c588880ebdf996666d70c9e7f53da8 100644 --- a/tensorflow/contrib/cmake/external/jpeg.cmake +++ b/tensorflow/contrib/cmake/external/jpeg.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive) -set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz) +set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz) set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7) set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg) set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install) diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index 28ec833babe8f8e600c7c0179dff511ce4d26105..79971b7cfc3c72e4b6290ccb71d40a20d1180c01 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb) -set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) +set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326) set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb) set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index d600d8c3c0d30ec517d0abc4bac94c588b5268d4..1e300e21df17eeee0abfc2becdab746fbfc62ff6 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) -set(PROTOBUF_URL https://github.com/mrry/protobuf.git) # Includes MSVC fix. -set(PROTOBUF_TAG 1d2c7b6c7376f396c8c7dd9b6afd2d4f83f3cb05) +set(PROTOBUF_URL https://github.com/google/protobuf.git) +set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) if(WIN32) set(protobuf_STATIC_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index a35d8654fb6fa5f5b5d230ffbc061d050e5aeb5e..2d2451521c0f9127e2c76e6270694ac21fe8db93 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -47,4 +47,4 @@ ExternalProject_Add(snappy ) # actually enables snappy in the source code -add_definitions(-DSNAPPY) \ No newline at end of file +add_definitions(-DTF_USE_SNAPPY) diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c5a101812710f0e6eb0aa8816acd2b395e7f7472..f3882e8cf76c6dad31371fc340de959c05411a2f 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,6 +21,8 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc" + "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 61c6686ee0e50a69c173ef47dd1e72d9bb5a982f..65565aad7ea5469926f320839455cd884a343713 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -33,6 +33,8 @@ else(tensorflow_BUILD_ALL_KERNELS) "${tensorflow_source_dir}/tensorflow/core/kernels/matmul_op.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/no_op.h" "${tensorflow_source_dir}/tensorflow/core/kernels/no_op.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/ops_util.h" + "${tensorflow_source_dir}/tensorflow/core/kernels/ops_util.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/sendrecv_ops.h" "${tensorflow_source_dir}/tensorflow/core/kernels/sendrecv_ops.cc" ) @@ -65,6 +67,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" @@ -74,6 +78,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/bipartite_match_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" @@ -167,6 +178,7 @@ endif(WIN32) file(GLOB_RECURSE tf_core_gpu_kernels_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op_gpu.cu.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc" ) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 78bccc08a36b3be922867b410c75cda87e5d0983..15e9a4c461385a120aa3152c56d7284729f71eb5 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== set(tf_op_lib_names + "audio_ops" "array_ops" "bitwise_ops" "candidate_sampling_ops" @@ -43,6 +44,7 @@ set(tf_op_lib_names "state_ops" "stateless_random_ops" "string_ops" + "summary_ops" "training_ops" ) @@ -84,6 +86,8 @@ GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(image_distort_image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 1e78f1e983f5b90d8a06c2baa08109193c4aa172..1b9fd514fde8c2b3de3de468bdd88adf50c94382 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -266,12 +266,14 @@ add_python_module("tensorflow/python/keras/_impl/keras/utils") add_python_module("tensorflow/python/keras/_impl/keras/wrappers") add_python_module("tensorflow/python/kernel_tests") add_python_module("tensorflow/python/kernel_tests/distributions") +add_python_module("tensorflow/python/kernel_tests/linalg") add_python_module("tensorflow/python/layers") add_python_module("tensorflow/python/lib") add_python_module("tensorflow/python/lib/core") add_python_module("tensorflow/python/lib/io") add_python_module("tensorflow/python/ops") add_python_module("tensorflow/python/ops/distributions") +add_python_module("tensorflow/python/ops/linalg") add_python_module("tensorflow/python/ops/losses") add_python_module("tensorflow/python/platform") add_python_module("tensorflow/python/platform/default") @@ -345,6 +347,8 @@ add_python_module("tensorflow/contrib/distributions/python") add_python_module("tensorflow/contrib/distributions/python/kernel_tests") add_python_module("tensorflow/contrib/distributions/python/ops") add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") +add_python_module("tensorflow/contrib/eager") +add_python_module("tensorflow/contrib/eager/python") add_python_module("tensorflow/contrib/estimator") add_python_module("tensorflow/contrib/estimator/python") add_python_module("tensorflow/contrib/estimator/python/estimator") @@ -497,6 +501,7 @@ add_python_module("tensorflow/contrib/lookup") add_python_module("tensorflow/contrib/losses") add_python_module("tensorflow/contrib/losses/python") add_python_module("tensorflow/contrib/losses/python/losses") +add_python_module("tensorflow/contrib/losses/python/metric_learning") add_python_module("tensorflow/contrib/makefile") add_python_module("tensorflow/contrib/makefile/test") add_python_module("tensorflow/contrib/memory_stats") @@ -637,6 +642,7 @@ add_python_module("tensorflow/contrib/reduce_slice_ops/ops") add_python_module("tensorflow/contrib/reduce_slice_ops/python") add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") +add_python_module("tensorflow/contrib/summary") # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -773,6 +779,10 @@ GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/input_pipeline/ops/gen_input_pipeline_ops.py) GENERATE_PYTHON_OP_LIB("contrib_image_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_image_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_image_distort_image_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_distort_image_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_image_sirds_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_single_image_random_dot_stereograms_ops.py) GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" @@ -805,6 +815,8 @@ GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) +GENERATE_PYTHON_OP_LIB("summary_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/summary/gen_summary_ops.py) add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) @@ -867,6 +879,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_writer.cc" "${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.h" "${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.cc" + "${tensorflow_source_dir}/tensorflow/python/util/util.h" + "${tensorflow_source_dir}/tensorflow/python/util/util.cc" "${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc" "${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc" "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.cc" diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index ba78e87ac04d365c4c28273768111ba1fb6e783d..77d21249148cc900a1bb4fc2742956aee47734de 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -152,6 +152,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" @@ -178,6 +179,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # exclude the ones we don't want set(tf_test_src_py_exclude + # generally excluded + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" + # Python source line inspection tests are flaky on Windows (b/36375074). "${tensorflow_source_dir}/tensorflow/python/debug/cli/analyzer_cli_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py" @@ -187,19 +191,16 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" # generally not working - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/resource_variable_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/pprof_profiler_test.py" # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" + # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" # requires scipy "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" - # flaky tests + # Takes very long to run without sharding (defined in bazel build file). "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py" # Loading resources in contrib doesn't seem to work on Windows "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py" @@ -212,44 +213,58 @@ if (tensorflow_BUILD_PYTHON_TESTS) if (WIN32) set(tf_test_src_py_exclude ${tf_test_src_py_exclude} - # generally excluded - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" - # TODO: failing tests. # Nothing critical in here but should get this list down to [] # The failing list is grouped by failure source + # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cast_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/clip_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/tensor_array_ops_test.py" # Needs portpicker. - # Matrix_set_diag failing on GPU on windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cholesky_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_ops_test.py" - # misc - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reshape_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py" # Depends on gemmlowp -> pthread. + # Numerical issues, calculations off. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + # Float division by zero + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" + # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/batch_matmul_op_test.py" + # Flaky because of local cluster creation. + "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" + # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" + # IteratorGetMax OutOfRangeError + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py" + # Depends on gemmlowp -> pthread + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py" # int32/int64 mixup + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cast_op_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py" + # Windows file management related issues. + "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # training tests "${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix. - "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows. "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename. - "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks. - - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_multi_gpu_test.py" # Fails on multiple GPUs. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" # numerical issues + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Dataset tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" # Broken tensorboard test due to cmake issues. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) @@ -258,8 +273,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py" # Bad placement. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/topn_test.py" # Results inaccurate "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support - # Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. # Dask.Dataframe bugs on Window Build "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py" @@ -268,39 +281,19 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Need extra build "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py" "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/depthtospace_op_test.py" # QuantizeV2 + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/spacetodepth_op_test.py" # QuantizeV2 # Windows Path "${tensorflow_source_dir}/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py" #TODO: Fix path - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/models_test.py" - # Related to Windows Multiprocessing https://github.com/fchollet/keras/issues/5071 - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/engine/training_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/callbacks_test.py" - # Scipy needed - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/kmeans_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py" - # Failing with TF 1.3 (TODO) - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py" + # Numpy upgrade needed? "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py" # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 94aff13a49f5380d5804e190b33613fd42dcaebc..2108e42bce4eba1eed158fe85888f1699a69ba7e 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -173,12 +173,12 @@ class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): with self.test_session(): - x = constant_op.constant(3) - y_nc = math_ops.add(x, x, name="not_compiled") + x = constant_op.constant([[3]]) + y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): - y_c = math_ops.add(y_nc, y_nc, name="compiled") + y_c = math_ops.matmul(y_nc, y_nc, name="compiled") x_grads = gradients.gradients([y_c], [x])[0] - operations = x_grads.graph.get_operations() + operations = x.graph.get_operations() c_grad_ops = [ op for op in operations if "gradients/compiled" in op.name] nc_grad_ops = [ @@ -191,19 +191,19 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "No attr named"): ncg.get_attr("_XlaCompile") - # d/dx (4 * x) - self.assertAllClose(4, x_grads.eval()) + # d/dx (x ** 4) = 4 * (x ** 3) + self.assertAllClose([[108]], x_grads.eval()) def testCompilationGradientScopeNames(self): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) @@ -220,12 +220,12 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index d4214587cd1a0fa684710d37083028f9af0425d9..ae9413fdd63cafed306b10a0f68f3fc0315a22c3 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -54,7 +54,7 @@ tf_gen_op_wrapper_py( ) tf_custom_op_py_library( - name = "cudnn_rnn_py", + name = "cudnn_rnn_ops_py", srcs = [ "__init__.py", "python/ops/cudnn_rnn_ops.py", @@ -81,10 +81,67 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "cudnn_rnn_py", + srcs = [ + "__init__.py", + "python/layers/cudnn_rnn.py", + ], + dso = [ + ":python/ops/_cudnn_rnn_ops.so", + ], + kernels = [ + ":cudnn_rnn_kernels", + ":cudnn_rnn_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":cudnn_rnn_ops", + ":cudnn_rnn_ops_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + cuda_py_test( name = "cudnn_rnn_ops_test", size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], + additional_deps = [ + ":cudnn_rnn_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python/ops/losses:losses", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + shard_count = 6, + tags = [ + "manual", + "requires_cudnn5", + ], +) + +cuda_py_test( + name = "cudnn_rnn_test", + size = "large", + srcs = ["python/kernel_tests/cudnn_rnn_test.py"], additional_deps = [ ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", @@ -114,7 +171,7 @@ cuda_py_test( size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_py", + ":cudnn_rnn_ops_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce8954bb09d7444a552d0ba6b3d9bb72cd919fd --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -0,0 +1,1050 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Cudnn RNN models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import os +import unittest + +import numpy as np + +from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gradients_impl as gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn as rnn_lib +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import saver as saver_lib + +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION + +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + + +class CudnnTestModel(object): + """Model with convenient APIs for easier building and running test graph. + + The graph built is used by all tests below to avoid repeatedly building + similar test graphs. + """ + + def __init__(self, + rnn_mode, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + dtype=dtypes.float32, + training=False, + kernel_initializer=None, + bias_initializer=None): + if dtype not in (dtypes.float32, dtypes.float64): + raise ValueError("Invalid dtype: %s" % dtype) + self._dtype = dtype + + self._inputs = array_ops.placeholder( + dtype=dtype, shape=[None, None, input_size], name="inputs") + h = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="h") + c = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="c") + if rnn_mode == CUDNN_LSTM: + model_fn = cudnn_rnn.CudnnLSTM + self._initial_state = (h, c) + elif rnn_mode == CUDNN_GRU: + model_fn = cudnn_rnn.CudnnGRU + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_TANH: + model_fn = cudnn_rnn.CudnnRNNTanh + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_RELU: + model_fn = cudnn_rnn.CudnnRNNRelu + self._initial_state = (h,) + else: + raise ValueError("Invalid rnn_mode: %s" % rnn_mode) + self._rnn = model_fn( + num_layers, + num_units, + direction=direction, + dropout=dropout, + dtype=dtype, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + self._rnn.build([None, None, input_size]) + + self._outputs, self._output_state = self._rnn( + self._inputs, initial_state=self._initial_state, training=training) + + def _AddUp(self, outputs, output_state): + total = math_ops.reduce_sum(outputs) + for s in output_state: + total += math_ops.reduce_sum(s) + return total + + @property + def inputs(self): + return self._inputs + + @property + def initial_state(self): + return self._initial_state + + @property + def outputs(self): + return self._outputs + + @property + def output_state(self): + return self._output_state + + @property + def rnn(self): + return self._rnn + + @property + def total_sum(self): + return self._AddUp(self.outputs, self.output_state) + + def SynthesizeInput(self, seq_length, batch_size, seed=1234): + """Synthesizes input and initial state values for testing.""" + np.random.seed(seed) + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + input_size = self._rnn.input_size + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + inputs = np.random.randn(seq_length, batch_size, + input_size).astype(np_dtype) + input_h = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return inputs, initial_state + + def ZeroState(self, batch_size): + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + input_h = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return initial_state + + def FProp(self, inputs_t, initial_state_t, training): + """Builds additional subgraph with given inputs and state. + + Args: + inputs_t: a tensor. + initial_state_t: a tensor. + training: boolean, true if training mode. + Returns: + A tensor of the forward pass output of the model. + """ + outputs, output_state = self._rnn( + inputs_t, initial_state=initial_state_t, training=training) + return self._AddUp(outputs, output_state) + + def Feed(self, sess, inputs, initial_state=None, return_sum=True): + """Runs graph with given inputs and initial state.""" + batch_size = inputs.shape[1] + if initial_state is None: + initial_state = self.ZeroState(batch_size) + if return_sum: + return sess.run( + self.total_sum, + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + else: + return sess.run( + [self.outputs, self.output_state], + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + + +def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None): + mode = rnn.rnn_mode + num_units = rnn.num_units + num_layers = rnn.num_layers + + # To reuse cuDNN-trained models, must use cudnn compatible rnn cells. + if mode == CUDNN_LSTM: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units) + elif mode == CUDNN_GRU: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units) + elif mode == CUDNN_RNN_TANH: + single_cell = (lambda: rnn_cell_impl.BasicRNNCell(num_units, math_ops.tanh)) + elif mode == CUDNN_RNN_RELU: + single_cell = ( + lambda: rnn_cell_impl.BasicRNNCell(num_units, gen_nn_ops.relu)) + else: + raise ValueError("%s is not supported!" % mode) + + if not is_bidi: + cell = rnn_cell_impl.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + return rnn_lib.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) + else: + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + (outputs, output_state_fw, + output_state_bw) = contrib_rnn_lib.stack_bidirectional_dynamic_rnn( + cells_fw, + cells_bw, + inputs, + dtype=dtypes.float32, + time_major=True, + scope=scope) + return outputs, (output_state_fw, output_state_bw) + + +class CudnnRNNTestBasic(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testLayerBasic(self): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + with vs.variable_scope("main"): + kernel_initializer = init_ops.constant_initializer(0.) + bias_initializer = init_ops.constant_initializer(0.) + inputs = random_ops.random_uniform([ + num_layers * dir_count, batch_size, num_units], dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Build the layer + outputs1, _ = lstm(inputs) + # Reuse the layer + outputs2, _ = lstm(inputs) + + total_sum1 = math_ops.reduce_sum(outputs1) + total_sum2 = math_ops.reduce_sum(outputs2) + + with vs.variable_scope("main", reuse=True): + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Reuse the layer + outputs3, _ = lstm(inputs) + total_sum3 = math_ops.reduce_sum(outputs3) + + self.assertEqual(1, len(variables.trainable_variables())) + self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))) + self.assertEqual("main/awesome_lstm/opaque_kernel", + variables.trainable_variables()[0].op.name) + + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + (total_sum1_v, total_sum2_v, total_sum3_v) = sess.run( + [total_sum1, total_sum2, total_sum3]) + self.assertEqual(0, total_sum1_v) + self.assertEqual(0, total_sum2_v) + self.assertEqual(0, total_sum3_v) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. +class CudnnRNNTestSaveRestore(TensorFlowTestCase): + + def _CompareWeights(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): + self.assertEqual(len(lhs), len(rhs)) + if rnn_mode == CUDNN_LSTM: + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_GRU: + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + else: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 + num_params_per_layer *= num_dirs + self.assertEqual(num_params_per_layer * num_layers, len(lhs)) + + for i in range(num_layers): + layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + if direction == CUDNN_RNN_UNIDIRECTION: + self._CompareSingleLayerBiases(layer_lhs, layer_rhs) + else: + size = len(layer_lhs) + fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] + fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] + self._CompareSingleLayerBiases(fw_lhs, fw_rhs) + self._CompareSingleLayerBiases(bw_lhs, bw_rhs) + + def _CompareSingleLayerBiases(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + + lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] + lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] + self.assertEqual(len(lf_lhs), len(rt_lhs)) + self.assertEqual(len(lf_rhs), len(rt_rhs)) + + sum_lhs, sum_rhs = [], [] + for lf, rt in zip(lf_lhs, rt_lhs): + sum_lhs.append(lf + rt) + for lf, rt in zip(lf_rhs, rt_rhs): + sum_rhs.append(lf + rt) + self.assertEqual(len(sum_lhs), len(sum_rhs)) + for lf, rt in zip(sum_lhs, sum_rhs): + self.assertAllEqual(lf, rt) + + def _TestSaveRestoreVariable(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + rnn = model.rnn + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test") + saver = saver_lib.Saver() + weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + opaque_params = rnn.trainable_variables[0] + # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save + # Cudnn vars in canonical format. + reset_op = state_ops.assign( + opaque_params, + array_ops.zeros(array_ops.shape(opaque_params), dtype=dtype)) + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + weights_v, biases_v = sess.run([weights, biases]) + + # Reset opaque param + sess.run(reset_op) + saver.restore(sess, save_path) + weights_v_restored, biases_v_restored = sess.run([weights, biases]) + + self._CompareWeights(weights_v, weights_v_restored) + self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + with vs.variable_scope("m1"): + model1 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + with vs.variable_scope("m2"): + model2 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + opaque_params = (model1.rnn.trainable_variables[0], + model2.rnn.trainable_variables[0]) + weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() + weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + reset_params = [ + state_ops.assign(params, + array_ops.zeros_like(params, dtype=dtype)) + for params in opaque_params + ] + reset_op = control_flow_ops.group(*reset_params) + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test2") + saver = saver_lib.Saver() + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + weights1_v, biases1_v = sess.run([weights1, biases1]) + weights2_v, biases2_v = sess.run([weights2, biases2]) + + sess.run(reset_op) + saver.restore(sess, save_path) + weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) + weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) + + self._CompareWeights(weights1_v, weights1_v_restored) + self._CompareWeights(weights2_v, weights2_v_restored) + self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, + direction) + self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreOutput(self, rnn_mode, direction, dtype): + with ops.Graph().as_default() as g: + num_layers = 2 + num_units = 7 + input_size = 7 + seq_length = 8 + batch_size = 4 + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + + save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") + saver = saver_lib.Saver() + + # Only one opaque var in a cudnn layer. + assert len(rnn.trainable_variables) == 1 + reset_params = state_ops.assign( + rnn.trainable_variables[0], + array_ops.zeros( + array_ops.shape(rnn.trainable_variables[0]), dtype=dtype)) + + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + inputs, initial_state = model.SynthesizeInput(seq_length, batch_size) + total_sum_v = model.Feed(sess, inputs, initial_state) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + sess.run(reset_params) + saver.restore(sess, save_path) + total_sum_v_restored = model.Feed(sess, inputs, initial_state) + self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) + + def _TestSaveRestoreHelper(self, rnn_mode): + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + dtype_list = [dtypes.float32, dtypes.float64] + for direction, dtype in itertools.product(directions, dtype_list): + self._TestSaveRestoreVariable(rnn_mode, direction, dtype) + self._TestSaveRestoreTwoVariables(rnn_mode, direction, dtype) + self._TestSaveRestoreOutput(rnn_mode, direction, dtype) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRepeatedlyCreateCustomSaveable(self): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default(): + random_seed.set_random_seed(1234) + model = CudnnTestModel( + CUDNN_LSTM, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32) + with self.assertRaisesRegexp(RuntimeError, + "Cudnn saveable already created"): + model.rnn._create_saveable() + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreLSTM(self): + self._TestSaveRestoreHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreGRU(self): + self._TestSaveRestoreHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNTanh(self): + self._TestSaveRestoreHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNRelu(self): + self._TestSaveRestoreHelper(CUDNN_RNN_RELU) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. +class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleLSTM(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleGRU(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNTanh(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNRelu(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_RELU) + + def _TestCudnnCompatibleRnnCellsHelper(self, rnn_mode): + configs = [ + { + "num_layers": 1, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 2, + "seq_length": 8, + "num_units": 4, + "input_size": 8, + "batch_size": 16, + }, + { + "num_layers": 2, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 1, + "seq_length": 2, + "num_units": 2, + "input_size": 4, + "batch_size": 1, + }, + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + for cfg, direction in zip(configs, directions): + self._TestCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], + cfg["num_units"], cfg["input_size"], + cfg["batch_size"], rnn_mode, direction) + + def _TestCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units, + input_size, batch_size, rnn_mode, direction): + dtype = dtypes.float32 + # Train graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=True) + target_output = array_ops.placeholder(dtype=dtype) + loss_op = losses.log_loss( + labels=target_output, predictions=model.total_sum) + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss_op) + + saver = saver_lib.Saver() + + # Train Cudnn model + seed = 0 + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + # Train 128 steps + num_steps = 128 + for _ in range(num_steps): + inputs, _ = model.SynthesizeInput(seq_length, batch_size, seed) + targets = np.random.rand() + sess.run( + train_op, + feed_dict={ + model.inputs: inputs, + model.initial_state: model.ZeroState(batch_size), + target_output: targets + }) + seed += 1 + + save_path = os.path.join(self.get_temp_dir(), + ("cudnn-rnn-%s-test" % rnn_mode)) + save_v = saver.save(sess, save_path) + self.assertEqual(save_path, save_v) + + # Cudnn inference graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + saver = saver_lib.Saver() + + inference_input = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + saver.restore(sess, save_path) + + # Cudnn inference + cudnn_outputs_v, cudnn_output_states_v = model.Feed( + sess, inference_input, return_sum=False) + + # Canonical RNN inference graph + with ops.Graph().as_default() as g: + cell_inputs = array_ops.placeholder( + dtype, shape=[seq_length, batch_size, input_size]) + if direction == CUDNN_RNN_UNIDIRECTION: + # outputs is one tensor, states are num_layer tuples, each 2 tensors + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN(rnn, cell_inputs) + if rnn_mode == CUDNN_LSTM: + output_h = array_ops.stack([s.h for s in states]) + output_c = array_ops.stack([s.c for s in states]) + else: + output_state = array_ops.stack([s for s in states]) + else: + # outputs is one tensor. + # states is a tuple of 2 tuples: + # each sub tuple is num_layer tuples, each with 2 tensors. + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN( + rnn, cell_inputs, is_bidi=True) + output_state_fw, output_state_bw = states + if rnn_mode == CUDNN_LSTM: + output_h, output_c = [], [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_h.append(array_ops.stack([s_fw.h, s_bw.h])) + output_c.append(array_ops.stack([s_fw.c, s_bw.c])) + output_h = array_ops.concat(output_h, axis=0) + output_c = array_ops.concat(output_c, axis=0) + else: + output_state = [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_state.append(array_ops.stack([s_fw, s_bw])) + output_state = array_ops.concat(output_state, axis=0) + saver = saver_lib.Saver() + + with self.test_session(use_gpu=True, graph=g) as sess: + saver.restore(sess, save_path) + + # BlockCell inference + if rnn_mode == CUDNN_LSTM: + outputs_v, output_h_v, output_c_v = sess.run( + [outputs, output_h, output_c], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v) + cudnn_output_h_v, cudnn_output_c_v = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_h_v) + self.assertAllClose(cudnn_output_c_v, output_c_v) + else: + outputs_v, output_state_v = sess.run( + [outputs, output_state], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-5, rtol=1e-5) + (cudnn_output_h_v,) = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_state_v, atol=1e-5, + rtol=1e-5) + + +class CudnnRNNTestParamsSize(TensorFlowTestCase): + + def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size, + direction): + logging.info("Testing one lstm param size with config: %s", locals()) + dtype = dtypes.float32 + + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + dtype=dtype, + direction=direction) + rnn = model.rnn + + # Min param size estimate = sum(weights.size) + sum(biases.size) + min_params_size = ( + np.sum(map(np.prod, rnn.canonical_weight_shapes)) + + np.sum([sp[0] for sp in rnn.canonical_bias_shapes])) + + opaque_params = rnn.trainable_variables[0] + with self.test_session(use_gpu=True, graph=ops.get_default_graph()): + variables.global_variables_initializer().run() + opaque_params_size_v = opaque_params.eval().size + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testOpaqueParamsSize(self): + test_configs = [ + [4, 200, 200], + [4, 200, 300], + [4, 200, 100], + [1, 100, 200], + [2, 200, 100], + [3, 200, 400], + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + rnns = [CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH] + for (rnn, config, direction) in itertools.product(rnns, test_configs, + directions): + num_layers, num_units, input_size = config + with ops.Graph().as_default(): + self._TestOpaqueParamsSize(rnn, num_layers, num_units, input_size, + direction) + + +class CudnnRNNTestTraining(TensorFlowTestCase): + + def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1): + """Compute the numeric gradient of y wrt to x. + + Args: + sess: The TF session constructed with a graph containing x and y. + y: A scalar TF Tensor in the graph constructed in sess. + x: A TF Tensor in the graph constructed in sess. + delta: Gradient checker's small perturbation of x[i]. + step: Only compute numerical gradients for a subset of x values. + I.e. dy/dx[i] is computed if i % step == 0. + Returns: + A Tensor of the same shape and dtype as x. If x[i] is not chosen + to compute the numerical gradient dy/x[i], the corresponding + value is set to 0. + """ + + x_data = sess.run(x) + x_size = x_data.size + x_shape = x_data.shape + + numeric_grad = np.zeros(x_size, dtype=x_data.dtype) + + for i in range(0, x_size, step): + x_pos = x_data.copy() + if x_size == 1: + x_pos += delta + else: + x_pos.flat[i] += delta + y_pos_feed_dict = dict([(x.name, x_pos)]) + y_pos = sess.run(y, feed_dict=y_pos_feed_dict) + + x_neg = x_data.copy() + if x_size == 1: + x_neg -= delta + else: + x_neg.flat[i] -= delta + y_neg_feed_dict = dict([(x.name, x_neg)]) + y_neg = sess.run(y, feed_dict=y_neg_feed_dict) + numeric_grad[i] = (y_pos - y_neg) / (2 * delta) + return numeric_grad.reshape(x_shape) + + def _GradientCheck(self, sess, y, xs, tolerance=1e-6, delta=1e-4): + sym_grads_t = gradients.gradients(y, xs) + sym_grads = sess.run(sym_grads_t) + + num_grads = [self._ComputeNumericGrad(sess, y, x, delta) for x in xs] + self.assertEqual(len(sym_grads), len(num_grads)) + for sym, num in zip(sym_grads, num_grads): + self.assertFalse(np.any(np.isnan(sym))) + self.assertFalse(np.any(np.isnan(num))) + self.assertAllClose(sym, num, atol=tolerance, rtol=tolerance) + + def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, + batch_size, seq_length, dir_count, dropout, dtype, + delta, tolerance): + # Gradient checking runs two forward ops with almost the same input. Need to + # make sure the drop patterns across the two runs are the same. + logging.info("Training test with config: %s", locals()) + old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) + random_seed.set_random_seed(5678) + has_input_c = (rnn_mode == CUDNN_LSTM) + direction = (CUDNN_RNN_UNIDIRECTION + if dir_count == 1 else CUDNN_RNN_BIDIRECTION) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dropout=dropout, + dtype=dtype, + training=True, + bias_initializer=init_ops.random_normal_initializer( + mean=1., dtype=dtype)) + rnn = model.rnn + params = rnn.trainable_variables[0] + + inputs = variables.Variable( + random_ops.random_uniform( + [seq_length, batch_size, input_size], dtype=dtype), + dtype=dtype) + input_h = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + if has_input_c: + input_c = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + total_sum = model.FProp(inputs, initial_state, training=True) + + with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + all_inputs = [inputs, params] + for s in initial_state: + all_inputs.append(s) + self._GradientCheck( + sess, total_sum, all_inputs, tolerance=tolerance, delta=delta) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state + + def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): + dropouts = [0., 0.5, 1.] + for config, dropout in itertools.product(test_configs, dropouts): + dtype = config.get("dtype", dtypes.float32) + delta = config.get("delta", 1e-4) + tolerance = config.get("tolerance", 1e-6) + dir_count = config.get("dir_count", 1) + shape = config["shape"] + with ops.Graph().as_default(): + self._TestOneSimpleTraining(rnn_mode, shape["num_layers"], + shape["num_units"], shape["input_size"], + shape["batch_size"], shape["seq_length"], + dir_count, dropout, dtype, delta, tolerance) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTM64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTM32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-4, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRU64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + } + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRU32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 4e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanh64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanh32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 5e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNRelu64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNRelu32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 7e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..f6c206022c68c4ba78d895f44288f4b180d199c0 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -0,0 +1,556 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cudnn RNN operators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as base_layer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + +_cudnn_rnn_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) + +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH + +# Half for cell input, half for hidden states. +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + +CUDNN_INPUT_LINEAR_MODE = cudnn_rnn_ops.CUDNN_INPUT_LINEAR_MODE +CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE +CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE + + +class _CudnnRNN(base_layer.Layer): + # pylint:disable=line-too-long + """Abstract class for RNN layers with Cudnn implementation. + + Cudnn RNNs have two major differences from other platform-independent RNNs tf + provides: + * Cudnn LSTM and GRU are mathematically different from their tf counterparts. + (e.g. @{tf.contrib.rnn.LSTMBlockCell} and @{tf.nn.rnn_cell.GRUCell}. + * Cudnn-trained checkpoints are not directly compatible with tf RNNs: + * They use a single opaque parameter buffer for the entire (possibly) + multi-layer multi-directional RNN; Whereas tf RNN weights are per-cell and + layer. + * The size and layout of the parameter buffers may change between + CUDA/CuDNN/GPU generations. Because of that, the opaque parameter variable + does not have a static shape and is not partitionable. Instead of using + partitioning to alleviate the PS's traffic load, try building a + multi-tower model and do gradient aggregation locally within the host + before updating the PS. See https://www.tensorflow.org/performance/performance_models#parameter_server_variables + for a detailed performance guide. + + Consequently, if one plans to use Cudnn trained models on both GPU and CPU + for inference and training, one needs to: + * Create a CudnnOpaqueParamsSaveable subclass object to save RNN params in + canonical format. (This is done for you automatically during layer building + process.) + * When not using a Cudnn RNN class, use CudnnCompatibleRNN classes to load the + checkpoints. These classes are platform-independent and perform the same + computation as Cudnn for training and inference. + Similarly, CudnnCompatibleRNN-trained checkpoints can be loaded by CudnnRNN + classes seamlessly. + + Below is a typical workflow(using LSTM as an example): + for detailed performance guide. + + # Use Cudnn-trained checkpoints with CudnnCompatibleRNNs + ```python + with tf.Graph().as_default(): + lstm = CudnnLSTM(num_layers, num_units, direction, ...) + + outputs, output_states = lstm(inputs, initial_states, training=True) + + # If user plans to delay calling the cell with inputs, one can do + # lstm.build(input_shape) + + saver = Saver() + + # training subgraph + ... + + # Once in a while save the model. + saver.save(save_path) + + # Inference subgraph for unidirectional RNN on, e.g., CPU or mobile. + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + + # NOTE: Even if there's only one layer, the cell needs to be wrapped in + # MultiRNNCell. + cell = tf.nn.rnn_cell.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + + # Leave the scope arg unset. + outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state, ...) + + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + + # Inference subgraph for bidirectional RNN + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + # Leave the scope arg unset. + (outputs, output_state_fw, + output_state_bw) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn( + cells_fw, cells_bw, inputs, ...) + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + ``` + """ + # pylint:enable=line-too-long + + # The following are constants defined by subclasses. + # Type of RNN cell. + _rnn_mode = None + # Number of cell weights(or biases) per layer. + _num_params_per_layer = None + # Custom SaveableObject class for the CudnnRNN class. + _saveable_cls = None + + # TODO(jamesqin): support float16 CuDNN RNN + def __init__(self, + num_layers, + num_units, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=None, + dtype=dtypes.float32, + kernel_initializer=None, + bias_initializer=None, + name=None): + """Creates a CudnnRNN model from model spec. + + Args: + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It can be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Can be either + 'unidirectional' or 'bidirectional' + dropout: dropout rate, a number between [0, 1]. Dropout is applied on + inputs of each layer. When set to 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + dtype: tf.float32 or tf.float64 + kernel_initializer: starting value to initialize the weight. + bias_initializer: starting value to initialize the bias + (default is all zeros). + name: VariableScope for the created subgraph; defaults to class name. + This only serves the default scope if later no scope is specified when + invoking __call__(). + + Raises: + ValueError: if direction is invalid. Or dtype is not supported. + """ + super(_CudnnRNN, self).__init__(dtype=dtype, name=name) + cudnn_rnn_ops.check_direction(direction) + cudnn_rnn_ops.check_input_mode(input_mode) + + if dtype not in [dtypes.float32, dtypes.float64]: + raise ValueError("Only support float32, float64, provided %s" % dtype) + # Layer self.dtype is type name, the original DType object is kept here. + self._plain_dtype = dtype + self._num_layers = num_layers + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + self._dropout = dropout + self._seed = seed + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + # Init input_size to None, which will be set after build(). + self._input_size = None + self._saveable = None + + @property + def num_layers(self): + return self._num_layers + + @property + def num_units(self): + return self._num_units + + @property + def input_mode(self): + """Input mode of first layer. + + Indicates whether there is a linear projection between the input and the + actual computation before the first layer. It can be + * 'linear_input': (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior) + * 'skip_input': 'skip_input' is only allowed when input_size == num_units. + * 'auto_select'. implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + + Returns: + 'linear_input', 'skip_input' or 'auto_select'. + """ + return self._input_mode + + @property + def input_size(self): + if not self._input_size: + raise ValueError( + "\'input_size\' is unknown since layer has not been built.") + return self._input_size + + @property + def rnn_mode(self): + """Type of RNN cell used. + + Returns: + `lstm`, `gru`, `rnn_relu` or `rnn_tanh`. + """ + return self._rnn_mode + + @property + def direction(self): + """Returns `unidirectional` or `bidirectional`.""" + return self._direction + + @property + def num_dirs(self): + return 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + + @property + def saveable(self): + return self._saveable + + @property + def canonical_weight_shapes(self): + """Shapes of Cudnn canonical weight tensors.""" + if not self._input_size: + raise RuntimeError( + "%s.canonical_weight_shapes invoked before input shape is known" % + type(self).__name__) + + shapes = [] + for i in range(self._num_layers): + shapes.extend(self._canonical_weight_shape(i)) + return shapes + + @property + def canonical_bias_shapes(self): + """Shapes of Cudnn canonical bias tensors.""" + return self._canonical_bias_shape(0) * self._num_layers + + def _update_trainable_weights(self, getter, *args, **kwargs): + """Custom getter for layer variables.""" + # Add variables to layer's `(non_)trainable_weights` list(s). + variable = getter(*args, **kwargs) + trainable = kwargs.get("trainable", True) + if trainable and variable not in self._trainable_weights: + self._trainable_weights.append(variable) + elif not trainable and variable not in self._non_trainable_weights: + self._non_trainable_weights.append(variable) + return variable + + def build(self, input_shape): + """Create variables of the Cudnn RNN. + + It can be called manually before `__call__()` or automatically through + `__call__()`. In the former case, subsequent `__call__()`s will skip + creating variables. + Args: + input_shape: network input tensor shape, a python list or a TensorShape + object with 3 dimensions. + Raises: + ValueError: if input_shape has wrong dimension or unknown 3rd dimension. + """ + if self.built: + return + + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape.ndims != 3: + raise ValueError("Expecting input_shape with 3 dims, got %d" % + input_shape.ndims) + if input_shape[-1].value is None: + raise ValueError("The last dimension of the inputs to `CudnnRNN` " + "should be defined. Found `None`.") + self._input_size = input_shape[-1].value + self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + + self._set_scope(None) + + # Not using base class `add_variable()` since the it calls + # `tf.get_variable()` with a callable initializer whereas here with a + # tensor. The difference is mandated to support forward-compatibility with + # Cudnn. + with vs.variable_scope( + self._scope, + reuse=self.built, + custom_getter=self._update_trainable_weights): + if self._kernel_initializer is None: + self._kernel_initializer = init_ops.glorot_uniform_initializer( + seed=self._seed, dtype=self._plain_dtype) + if self._bias_initializer is None: + self._bias_initializer = init_ops.constant_initializer( + 0.0, dtype=self._plain_dtype) + + weights = [ + self._kernel_initializer(sp, dtype=self._plain_dtype) + for sp in self.canonical_weight_shapes + ] + biases = [ + self._bias_initializer(sp, dtype=self._plain_dtype) + for sp in self.canonical_bias_shapes + ] + opaque_params_t = self._canonical_to_opaque(weights, biases) + + if vs.get_variable_scope().partitioner is not None: + logging.warn( + "Partitioner is not supported for Cudnn RNN layer variables, using " + "it will create forward-compatibility issues with future " + "CUDA/CuDNN generations.") + # Initialize opaque params with a tensor. + self.kernel = vs.get_variable( + "opaque_kernel", initializer=opaque_params_t, validate_shape=False) + # Create saveable in the outer scope of the cudnn subgraph, such that + # alternative subgraph with platform-independent rnn cells can load the + # checkpoints directly. + if not (self.built or vs.get_variable_scope().reuse): + self._create_saveable() + self.built = True + + def call(self, inputs, initial_state=None, training=True): + """Runs the forward step for the RNN model. + + Args: + inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. + initial_state: a tuple of tensor(s) of shape + `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use + zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. + training: whether this operation will be used in training or inference. + Returns: + output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. + It is a `concat([fwd_output, bak_output], axis=2)`. + output_states: a tuple of tensor(s) of the same shape and structure as + `initial_state`. + Raises: + ValueError: initial_state is not a tuple. + """ + if initial_state is not None and not isinstance(initial_state, tuple): + raise ValueError("Invalid initial_state type: %s, expecting tuple.", + type(initial_state)) + dtype = self.dtype + inputs = ops.convert_to_tensor(inputs, dtype=dtype) + + batch_size = array_ops.shape(inputs)[1] + if initial_state is None: + initial_state = self._zero_state(batch_size) + if self._rnn_mode == CUDNN_LSTM: + h, c = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + else: + h, = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + h = ops.convert_to_tensor(h, dtype=dtype) + if self._rnn_mode == CUDNN_LSTM: + c = ops.convert_to_tensor(c, dtype=dtype) + else: + # For model that doesn't take input_c, replace with a dummy tensor. + c = array_ops.constant([], dtype=dtype) + outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, + training) + if self._rnn_mode == CUDNN_LSTM: + return outputs, (output_h, output_c) + else: + return outputs, (output_h,) + + def state_shape(self, batch_size): + raise NotImplementedError + + def _zero_state(self, batch_size): + res = [] + for sp in self.state_shape(batch_size): + res.append(array_ops.zeros(sp, dtype=self.dtype)) + return tuple(res) + + def _canonical_weight_shape(self, layer): + """Shapes of Cudnn canonical weight tensors for given layer.""" + if layer < 0 or layer >= self._num_layers: + raise ValueError("\'layer\' is not valid, got %s, expecting [%d, %d]" % + (layer, 0, self._num_layers-1)) + if not self._input_size: + raise RuntimeError( + "%s._canonical_weight_shape invoked before input shape is known" % + type(self).__name__) + + input_size = self._input_size + num_units = self._num_units + num_gates = self._num_params_per_layer // 2 + is_bidi = self._direction == CUDNN_RNN_BIDIRECTION + + if layer == 0: + wts_applied_on_inputs = [(num_units, input_size)] * num_gates + else: + if is_bidi: + wts_applied_on_inputs = [(num_units, 2 * num_units)] * num_gates + else: + wts_applied_on_inputs = [(num_units, num_units)] * num_gates + wts_applied_on_hidden_states = [(num_units, num_units)] * num_gates + tf_wts = wts_applied_on_inputs + wts_applied_on_hidden_states + return tf_wts if not is_bidi else tf_wts * 2 + + def _canonical_bias_shape(self, unused_layer): + """Shapes of Cudnn canonical bias tensors for given layer.""" + num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + return [[self._num_units]] * num_dirs * self._num_params_per_layer + + def _canonical_to_opaque(self, cu_weights, cu_biases): + if not self._input_size: + raise RuntimeError( + "%s._canonical_to_opaque invoked before input shape is known" % + type(self).__name__) + return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( + rnn_mode=self._rnn_mode, + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + input_mode=self._input_mode, + direction=self._direction) + + def _forward(self, inputs, h, c, opaque_params, training): + output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access + inputs, + h, + c, + opaque_params, + training, + self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + dropout=self._dropout, + seed=self._seed) + return output, (output_h, output_c) + + def _create_saveable(self): + """Create custom saveable for the Cudnn layer. + + Called during layer building process to make sharing checkpoints between + Cudnn and Cudnn-compatible RNNs easy. + Returns: + a `CudnnOpaqueParamsSaveable` object. + Raises: + RuntimeError: if any custom saveable is already created for this layer. + """ + if self._saveable is not None: + raise RuntimeError("Cudnn saveable already created.") + self._saveable = self._saveable_cls( # pylint:disable=not-callable + self.trainable_variables[0], + self.num_layers, + self.num_units, + self.input_size, + self.input_mode, + self.direction, + scope=vs.get_variable_scope(), + name="%s_saveable" % self.trainable_variables[0].op.name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + +class CudnnLSTM(_CudnnRNN): + """Cudnn implementation of LSTM layer.""" + _rnn_mode = CUDNN_LSTM + _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnLSTMSaveable + + def state_shape(self, batch_size): + """Shape of Cudnn LSTM states. + + Shape is a 2-element tuple. Each is + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return ([self.num_layers * self.num_dirs, batch_size, self.num_units], + [self.num_layers * self.num_dirs, batch_size, self.num_units]) + + +class _CudnnRNNNoInputC(_CudnnRNN): + """Abstract simple CudnnRNN layer without input_c.""" + + def state_shape(self, batch_size): + """Shape of the state of Cudnn RNN cells w/o. input_c. + + Shape is a 1-element tuple, + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return [self.num_layers * self.num_dirs, batch_size, self.num_units], + + +class CudnnGRU(_CudnnRNNNoInputC): + """Cudnn implementation of the GRU layer.""" + _rnn_mode = CUDNN_GRU + _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnGRUSaveable + + +class CudnnRNNTanh(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-tanh layer.""" + _rnn_mode = CUDNN_RNN_TANH + _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNTanhSaveable + + +class CudnnRNNRelu(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-relu layer.""" + _rnn_mode = CUDNN_RNN_RELU + _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNReluSaveable diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index bbf1bd9bca14320b4a2f8f9d30340c1fa64eb3da..7d658c746ee1ecd21cefca9c9e52f611869f6176 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -717,12 +717,6 @@ _cudnn_rnn_common_doc_string = """ """ -def _check_direction(direction): - if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s, expect %s or %s" % - (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) - - def _check_rnn_mode(rnn_mode): if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU): raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % @@ -737,14 +731,31 @@ def _get_seed(seed): return seed, seed2 +def check_direction(direction): + """Check validity of direction.""" + if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): + raise ValueError("Invalid direction: %s, expecting %s or %s" % + (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) + + +def check_input_mode(input_mode): + if input_mode not in (CUDNN_INPUT_LINEAR_MODE, CUDNN_INPUT_SKIP_MODE, + CUDNN_INPUT_AUTO_MODE): + raise ValueError("Invalid input_mode: %s, expect one of (%s, %s, %s)" % + (input_mode, CUDNN_INPUT_LINEAR_MODE, + CUDNN_INPUT_SKIP_MODE, CUDNN_INPUT_AUTO_MODE)) + + def _get_num_params(rnn_mode, num_layers, direction): """Return num params for given Cudnn config.""" if rnn_mode == CUDNN_LSTM: - num_params_per_layer = 8 + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER elif rnn_mode == CUDNN_GRU: - num_params_per_layer = 6 - elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH): - num_params_per_layer = 2 + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_RELU: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) num_params = num_layers * num_params_per_layer @@ -794,7 +805,8 @@ def _cudnn_rnn(inputs, outputs, output_h, output_c """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( input=inputs, @@ -1017,16 +1029,16 @@ def cudnn_rnn_tanh(inputs, seed, name) -def cudnn_rnn_params_to_canonical(rnn_mode, - num_layers, - num_units, - input_size, - params, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_to_canonical(rnn_mode, + num_layers, + num_units, + input_size, + params, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Convert cudnn opaque params to canonical. Args: @@ -1058,7 +1070,8 @@ def cudnn_rnn_params_to_canonical(rnn_mode, """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) num_params = _get_num_params(rnn_mode, num_layers, direction) seed, seed2 = random_seed.get_seed(seed) weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( @@ -1077,17 +1090,17 @@ def cudnn_rnn_params_to_canonical(rnn_mode, return weights, biases -def cudnn_rnn_canonical_to_params(rnn_mode, - num_layers, - num_units, - input_size, - weights, - biases, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_canonical_to_opaque_params(rnn_mode, + num_layers, + num_units, + input_size, + weights, + biases, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Converts params from the canonical format to a specific format of cuDNN. Args: @@ -1119,7 +1132,8 @@ def cudnn_rnn_canonical_to_params(rnn_mode, ValueError: if rnn_mode or direction is invalid. """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( rnn_mode=rnn_mode, @@ -1136,16 +1150,16 @@ def cudnn_rnn_canonical_to_params(rnn_mode, name=name) -def cudnn_opaque_params_size(rnn_mode, - num_layers, - num_units, - input_size, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_size(rnn_mode, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32, + dropout=0, + seed=0, + name=None): """Returns opaque params size for specific Cudnn config. Args: @@ -1176,7 +1190,8 @@ def cudnn_opaque_params_size(rnn_mode, ValueError: if rnn_mode or direction is invalid. """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_params_size( rnn_mode=rnn_mode, @@ -1278,7 +1293,7 @@ class _CudnnRNN(object): Returns: The calculated parameter buffer size. """ - return cudnn_opaque_params_size( + return cudnn_rnn_opaque_params_size( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1327,7 +1342,7 @@ class _CudnnRNN(object): Returns: A function for the specific-to-canonical conversion. """ - return cudnn_rnn_params_to_canonical( + return cudnn_rnn_opaque_params_to_canonical( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1348,7 +1363,7 @@ class _CudnnRNN(object): Returns: A function for the canonical-to-params-to-specific conversion.. """ - return cudnn_rnn_canonical_to_params( + return cudnn_rnn_canonical_to_opaque_params( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index ee96269a739ebb138ea88cf4e192f7925e85447d..b485d78f5c4a6867e00c5b4ad04f18c92af953a1 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -10,6 +10,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md index 04f0560b09297c65cd593232ba9f8daab8a0107a..30e909111f460bb4d0ea5fcdefaf5bdedc93b9c0 100644 --- a/tensorflow/contrib/data/README.md +++ b/tensorflow/contrib/data/README.md @@ -2,9 +2,38 @@ ===================== NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead. +We are continuing to support existing code using the `tf.contrib.data` APIs in +the current version of TensorFlow, but will eventually remove support. The +`tf.data` APIs are subject to backwards compatibility guarantees. -This directory contains the Python API for the `tf.contrib.data.Dataset` and -`tf.contrib.data.Iterator` classes, which can be used to build input pipelines. +Porting your code to `tf.data` +------------------------------ -The documentation for `tf.data` API has moved to the programmers' -guide, [here](../../docs_src/programmers_guide/datasets.md). +The `tf.contrib.data.Dataset` class has been renamed to `tf.data.Dataset`, and +the `tf.contrib.data.Iterator` class has been renamed to `tf.data.Iterator`. +Most code can be ported by removing `.contrib` from the names of the classes. +However, there are some small differences, which are outlined below. + +The arguments accepted by the `Dataset.map()` transformation have changed: + +* `dataset.map(..., num_threads=T)` is now `dataset.map(num_parallel_calls=T)`. +* `dataset.map(..., output_buffer_size=B)` is now + `dataset.map(...).prefetch(B). + +Some transformations have been removed from `tf.data.Dataset`, and you must +instead apply them using `Dataset.apply()` transformation. The full list of +changes is as follows: + +* `dataset.dense_to_sparse_batch(...)` is now + `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`. +* `dataset.enumerate(...)` is now + `dataset.apply(tf.contrib.data.enumerate_dataset(...))`. +* `dataset.group_by_window(...)` is now + `dataset.apply(tf.contrib.data.group_by_window(...))`. +* `dataset.ignore_errors()` is now + `dataset.apply(tf.contrib.data.ignore_errors())`. +* `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`. + +The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have +been removed from the API. Please open a GitHub issue if you have a need for +either of these. diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 4c32c72ad452b40d5f6b3dc0ab5be80a5ed23998..6c46acf20442c2cc435829afa57e8383b493d6af 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -27,11 +27,13 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@enumerate_dataset @@group_by_window @@ignore_errors +@@make_saveable_from_iterator @@read_batch_features @@unbatch @@rejection_resample @@sloppy_interleave +@@get_single_element """ from __future__ import absolute_import @@ -44,16 +46,18 @@ from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import unbatch from tensorflow.contrib.data.python.ops.dataset_ops import Dataset +from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave +from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample -from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index c34c9dad9b5afb1f1232c8bff4c26770199ce7b6..ff59e80b7994c510c9dbaf13be2cd475536485b7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -143,6 +143,29 @@ py_test( ], ) +py_test( + name = "interleave_dataset_op_test", + size = "small", + srcs = ["interleave_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "manual", # b/67958761 + ], + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "iterator_ops_cluster_test", size = "small", @@ -185,6 +208,7 @@ py_test( "//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:gradients", + "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:script_ops", @@ -244,6 +268,7 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -252,8 +277,11 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", "//tensorflow/python:variables", "//tensorflow/python/data/ops:iterator_ops", ], @@ -265,6 +293,7 @@ py_test( srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -274,9 +303,11 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", ], @@ -294,11 +325,8 @@ py_test( "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", "//tensorflow/python:string_ops", - "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python:variables", "//third_party/py/numpy", ], ) @@ -347,26 +375,6 @@ py_test( ], ) -py_test( - name = "sloppy_transformation_dataset_op_test", - size = "small", - srcs = ["sloppy_transformation_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - py_test( name = "sql_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 91f100e0f0ccd452ab9a9e673d6714b718ccbeb2..add17ff8bcea0f228dc36ec6157fe95b9ce44d80 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -256,8 +256,9 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetWithUnknownShape(self): components = np.random.randint(5, size=(40,)).astype(np.int32) iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x, x], x)).dense_to_sparse_batch( - 4, [5, -1]).make_initializable_iterator()) + .map(lambda x: array_ops.fill([x, x], x)).apply( + batching.dense_to_sparse_batch( + 4, [5, -1])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -285,7 +286,8 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) iterator = (dataset_ops.Dataset.from_tensors(input_tensor) - .dense_to_sparse_batch(4, [-2]).make_initializable_iterator()) + .apply(batching.dense_to_sparse_batch(4, [-2])) + .make_initializable_iterator()) init_op = iterator.initializer with self.test_session() as sess: @@ -424,6 +426,102 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) + def testBatchAndMapDataset(self): + """Test a dataset that maps a TF function across its input elements.""" + # The pipeline is TensorSliceDataset -> + # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count) + .apply(batching.map_and_batch(_map_fn, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.test_session() as sess: + # Batch of a finite input, where the batch_size divides the + # total number of elements. + sess.run(init_op, feed_dict={count: 28, batch_size: 14}) + num_batches = (28 * 7) // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of a finite input, where the batch_size does not + # divide the total number of elements. + sess.run(init_op, feed_dict={count: 14, batch_size: 8}) + + # We expect (num_batches - 1) full-sized batches. + num_batches = int(math.ceil((14 * 7) / 8)) + for i in range(num_batches - 1): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(8): + self.assertAllEqual(component[(i*8 + j) % 7]**2, + result_component[j]) + # The last batch should fail with `OutOfRange`. + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of an empty input should fail straight away. + sess.run(init_op, feed_dict={count: 0, batch_size: 8}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Empty batch should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + + def testBatchAndMapDatasetFails(self): + """Test a dataset that maps a TF function across its input elements.""" + dataset = dataset_ops.Dataset.from_tensors( + array_ops.check_numerics( + constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): + sess.run(init_op, feed_dict={batch_size: 14}) + + def testBatchAndMapDatasetShapeMismatch(self): + """Test a dataset that maps a TF function across its input elements.""" + def generator(): + yield [1] + yield [2] + yield [3] + yield [[4, 5, 6]] + + dataset = dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int32) + batch_size = 4 + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "number of elements does not match"): + sess.run(get_next) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py similarity index 84% rename from tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 880e01dc069a70ac4ccbbbc18865f631ddea74d8..0aa9ea88de82b0851b0236d9412039d6573ab291 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -25,7 +25,7 @@ import time from six.moves import zip_longest from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import sloppy_ops +from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -34,12 +34,13 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class SloppyInterleaveDatasetTest(test.TestCase): +class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) + self.sloppy = array_ops.placeholder(dtypes.bool, shape=[]) self.repeat_count = 2 @@ -69,9 +70,9 @@ class SloppyInterleaveDatasetTest(test.TestCase): self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) .repeat(self.repeat_count).apply( - sloppy_ops.sloppy_interleave( + interleave_ops.parallel_interleave( interleave_fn, self.cycle_length, - self.block_length))) + self.block_length, self.sloppy))) self.iterator = self.dataset.make_initializable_iterator() self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() @@ -161,7 +162,7 @@ class SloppyInterleaveDatasetTest(test.TestCase): for i in range(4, 7): self.write_coordination_events[i].set() - def testSingleThreaded(self): + def _testSingleThreaded(self, sloppy=False): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. with self.test_session() as sess: @@ -171,7 +172,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 1, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy }) for expected_element in self._interleave( @@ -182,7 +184,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContention(self): + def testSingleThreaded(self): + self._testSingleThreaded() + + def testSingleThreadedSloppy(self): + self._testSingleThreaded(sloppy=True) + + def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior with self.test_session() as sess: @@ -193,7 +201,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -211,11 +220,20 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContentionWithRaces(self): + def testTwoThreadsNoContention(self): + self._testTwoThreadsNoContention() + + def testTwoThreadsNoContentionSloppy(self): + self._testTwoThreadsNoContention(sloppy=True) + + def _testTwoThreadsNoContentionWithRaces(self, sloppy=False): """Tests where all the workers race in producing elements. Note: this is in contrast with the prevous test which carefully sequences the execution of the map functions. + + Args: + sloppy: Whether to be sloppy or not. """ with self.test_session() as sess: self._clear_coordination_events() @@ -225,7 +243,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -247,7 +266,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContentionBlockLength(self): + def testTwoThreadsNoContentionWithRaces(self): + self._testTwoThreadsNoContentionWithRaces() + + def testTwoThreadsNoContentionWithRacesSloppy(self): + self._testTwoThreadsNoContentionWithRaces(sloppy=True) + + def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior with self.test_session() as sess: @@ -258,7 +283,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 2 + self.block_length: 2, + self.sloppy: sloppy }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -276,11 +302,21 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContentionWithRacesAndBlocking(self): + def testTwoThreadsNoContentionBlockLength(self): + self._testTwoThreadsNoContentionBlockLength() + + def testTwoThreadsNoContentionBlockLengthSloppy(self): + self._testTwoThreadsNoContentionBlockLength(sloppy=True) + + def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False): """Tests where all the workers race in producing elements. Note: this is in contrast with the prevous test which carefully sequences the execution of the map functions. + + + Args: + sloppy: Whether to be sloppy or not. """ with self.test_session() as sess: self._clear_coordination_events() @@ -290,7 +326,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 2 + self.block_length: 2, + self.sloppy: sloppy }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -312,7 +349,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testEmptyInput(self): + def testTwoThreadsNoContentionWithRacesAndBlocking(self): + self._testTwoThreadsNoContentionWithRacesAndBlocking() + + def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self): + self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) + + def _testEmptyInput(self, sloppy=False): with self.test_session() as sess: # Empty input. self._clear_coordination_events() @@ -321,12 +364,19 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [], self.cycle_length: 2, - self.block_length: 3 + self.block_length: 3, + self.sloppy: sloppy }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testNonEmptyInputIntoEmptyOutputs(self): + def testEmptyInput(self): + self._testEmptyInput() + + def testEmptyInputSloppy(self): + self._testEmptyInput(sloppy=True) + + def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): # Non-empty input leading to empty output. with self.test_session() as sess: self._clear_coordination_events() @@ -335,12 +385,19 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [0, 0, 0], self.cycle_length: 2, - self.block_length: 3 + self.block_length: 3, + self.sloppy: sloppy }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testPartiallyEmptyOutputs(self): + def testNonEmptyInputIntoEmptyOutputs(self): + self._testNonEmptyInputIntoEmptyOutputs() + + def testNonEmptyInputIntoEmptyOutputsSloppy(self): + self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) + + def _testPartiallyEmptyOutputs(self, sloppy=False): # Mixture of non-empty and empty interleaved datasets. with self.test_session() as sess: self._clear_coordination_events() @@ -350,7 +407,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 0, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): @@ -367,7 +425,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testDelayedOutput(self): + def testPartiallyEmptyOutputs(self): + self._testPartiallyEmptyOutputs() + + def testPartiallyEmptyOutputsSloppy(self): + self._testPartiallyEmptyOutputs(sloppy=True) + + def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. with self.test_session() as sess: @@ -377,7 +441,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: True, }) mis_ordering = [ @@ -391,7 +456,7 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testBlockLengthWithContention(self): + def testBlockLengthWithContentionSloppy(self): with self.test_session() as sess: self._clear_coordination_events() done_first_event = False @@ -400,7 +465,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 3 + self.block_length: 3, + self.sloppy: True }) # Test against a generating sequence that differs from the uncontended # case, in order to prove sloppy correctness. @@ -422,7 +488,7 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testEarlyExit(self): + def _testEarlyExit(self, sloppy=False): # Exiting without consuming all input should not block with self.test_session() as sess: self._clear_coordination_events() @@ -431,7 +497,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 3, - self.block_length: 2 + self.block_length: 2, + self.sloppy: sloppy }) for i in range(4, 7): self.write_coordination_events[i].set() @@ -445,7 +512,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): self.read_coordination_events[i].acquire() self.write_coordination_events[i].set() - def testTooManyReaders(self): + def testEarlyExit(self): + self._testEarlyExit() + + def testEarlyExitSloppy(self): + self._testEarlyExit(sloppy=True) + + def _testTooManyReaders(self, sloppy=False): def interleave_fn(x): dataset = dataset_ops.Dataset.from_tensors(x) @@ -455,8 +528,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) dataset = dataset.repeat(self.repeat_count) dataset = dataset.apply( - sloppy_ops.sloppy_interleave(interleave_fn, cycle_length=16, - block_length=2)) + interleave_ops.parallel_interleave( + interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) iterator = dataset.make_one_shot_iterator() with self.test_session() as sess: @@ -468,6 +541,11 @@ class SloppyInterleaveDatasetTest(test.TestCase): [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) self.assertItemsEqual(output_values, expected_values) + def testTooManyReaders(self): + self._testTooManyReaders() + + def testTooManyReadersSloppy(self): + self._testTooManyReaders(sloppy=True) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 8d8cb574eaad565a1c1fce9b81efe25f974bcaea..bda9a2a4a37e9c3d35ff99041d1150ffc43f4c43 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops @@ -538,9 +539,23 @@ class IteratorTest(test.TestCase): def testIncorrectIteratorRestore(self): - def _iterator_checkpoint_prefix(): + def _path(): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + _path(), parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_range_dataset_graph(): start = 1 stop = 10 @@ -548,22 +563,18 @@ class IteratorTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = _iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op def _build_reader_dataset_graph(): filenames = ["test"] # Does not exist but we don't care in this test. - path = _iterator_checkpoint_prefix() iterator = readers.FixedLengthRecordDataset( filenames, 1, 0, 0).make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op # Saving iterator for RangeDataset graph. @@ -584,6 +595,31 @@ class IteratorTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(restore_op) + def testToSingleElement(self): + skip_value = array_ops.placeholder(dtypes.int64, shape=[]) + take_value = array_ops.placeholder_with_default( + constant_op.constant(1, dtype=dtypes.int64), shape=[]) + + dataset = (dataset_ops.Dataset.range(100) + .skip(skip_value) + .map(lambda x: x * x) + .take(take_value)) + + element = dataset_ops.get_single_element(dataset) + + with self.test_session() as sess: + self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) + self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) + self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset was empty."): + sess.run(element, feed_dict={skip_value: 100}) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset had more than one element."): + sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index c8a0072809c2eac30e255d29ecaee5a324449045..f59ac760dc83a504e563f055b91f1002cb0c80fc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -21,6 +21,7 @@ import os from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,9 +30,12 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class RangeDatasetTest(test.TestCase): @@ -193,6 +197,21 @@ class RangeDatasetTest(test.TestCase): def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def testSaveRestore(self): def _build_graph(start, stop): @@ -200,10 +219,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -244,16 +261,146 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testSaveRestoreUsingSaverFromMetaGraph(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) + # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection + # so that it can be automatically picked up by the Saver. + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = saver_lib.Saver() + return init_op, get_next, saver + + start = 2 + stop = 10 + break_point = 5 + path = self._iterator_checkpoint_prefix() + meta_filename = path + ".meta" + + # Execute input pipeline for a few steps and save iterator state. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + saver.save(sess, path) + + # Build the saver from the MetaGraph using import_meta_graph and + # check that the iterator state is restored. + with ops.Graph().as_default() as g: + saver = saver_lib.import_meta_graph(meta_filename) + init_op, get_next = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreUsingBuiltSaver(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection + # so that it can be automatically picked up by the Saver. + saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = saver_lib.Saver() + return init_op, get_next, saver + + start = 2 + stop = 10 + stop_new = 15 + break_point = 5 + path = self._iterator_checkpoint_prefix() + + # Execute input pipeline for a few steps and save iterator state. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + saver.save(sess, path) + + # Manually build a modified Graph and Saver instead of importing + # MetaGraph and verify that original iterator state gets restored. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop_new) + with self.test_session(graph=g) as sess: + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreUsingSaverThenInit(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection + # so that it can be automatically picked up by the Saver. + saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = saver_lib.Saver() + return init_op, get_next, saver + + start = 2 + stop = 10 + stop_new = 15 + break_point = 5 + path = self._iterator_checkpoint_prefix() + + # Execute input pipeline for a few steps and save iterator state. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + saver.save(sess, path) + + # Restore iterator state call and then call init_op for the iterator and + # verify that the new iterator hides the restored iterator. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop_new) + with self.test_session(graph=g) as sess: + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + sess.run(init_op) + for i in range(start, stop_new): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testRestoreWithoutBuildingDatasetGraph(self): - def _build_graph(start, stop, num_epochs, path): + def _build_graph(start, stop, num_epochs): dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -262,10 +409,8 @@ class RangeDatasetTest(test.TestCase): num_epochs = 5 break_point = 5 break_epoch = 3 - path = self._iterator_checkpoint_prefix() with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs, - path) + init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) with self.test_session(graph=g) as sess: sess.run(variables.global_variables_initializer()) sess.run(init_op) @@ -282,8 +427,7 @@ class RangeDatasetTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + restore_op = self._restore_op(iterator._iterator_resource) get_next = iterator.get_next() with self.test_session(graph=g) as sess: sess.run(restore_op) @@ -302,10 +446,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -343,10 +485,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -379,10 +519,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -424,10 +562,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -471,10 +607,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index c9f88f3dfc9a062ccd0bcabe7eadf18c98191c1d..3ae8f71d77fa6ecf08e42bedac702b8f75eec309 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,7 @@ import gzip import os import zlib +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 @@ -33,8 +34,10 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat @@ -162,6 +165,277 @@ class TextLineDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, saver, sess): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _build_graph(self, + test_filenames, + compression_type=None, + build_saveable=True): + ds = readers.TextLineDataset( + test_filenames, compression_type=compression_type, buffer_size=10) + iterator = ds.make_initializable_iterator() + if build_saveable: + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _testReadWithBreaks(self, breaks, num_files=5, lines_per_file=5): + """Tests reading from input pipeline with regular breaks. + + At each break point the iterator state gets saved using Saver and reloaded + in a new Graph and session. + + Args: + breaks: List of counts of records after reading which iterator state is + checkpointed. Must to in non-decreasing order. + num_files: Total number of files. + lines_per_file: Total number of lines per file. + """ + compression_types = [None, "GZIP", "ZLIB"] + for compression_type in compression_types: + test_filenames = self._createFiles( + num_files, + lines_per_file, + crlf=True, + compression_type=compression_type) + + # Collect ground truth. + total_records = num_files * lines_per_file + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type=compression_type) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(total_records): + expected_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Simulate run with breaks. + actual_records = [] + next_record_index = 0 + load_from_ckpt = False + breaks.append(total_records) + for break_index in breaks: + with ops.Graph().as_default() as g: + if not load_from_ckpt: + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type=compression_type) + else: + saver = self._import_meta_graph() + init_op, get_next = ops.get_collection("iterator_ops") + + with self.test_session(graph=g) as sess: + if not load_from_ckpt: + sess.run(init_op) + else: + self._restore(saver, sess) + while next_record_index != break_index: + actual_records.append(sess.run(get_next)) + next_record_index += 1 + if break_index == total_records: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self._save(saver, sess) + load_from_ckpt = True + self.assertEqual(actual_records, expected_records) + + def testSaveAtFileBoundary(self): + self._testReadWithBreaks([10]) + + def testSaveWithinFile(self): + self._testReadWithBreaks([12]) + + def testSaveUnusedIterator(self): + self._testReadWithBreaks([0]) + + def testSaveRestoreIdempotence(self): + # Attempt to save an iterator immediately after it has been + # restored. + self._testReadWithBreaks([0, 0]) + self._testReadWithBreaks([10, 10]) + self._testReadWithBreaks([12, 12]) + + def testMultipleBreaks(self): + self._testReadWithBreaks([0, 4, 20]) + + def testRestoreExhaustedIterator(self): + num_files = 2 + lines_per_file = 5 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_files * lines_per_file): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self._save(saver, sess) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + saver = self._import_meta_graph() + self._restore(saver, sess) + _, get_next = ops.get_collection("iterator_ops") + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testInitThenRestore(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + sess.run(get_next) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + saver = self._import_meta_graph() + init_op, get_next = ops.get_collection("iterator_ops") + sess.run(init_op) + self._restore(saver, sess) + for _ in range(total_records - break_record): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + + def testRestoreInModifiedGraph(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + sess.run(get_next) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type="GZIP") + self._restore(saver, sess) + for _ in range(total_records - break_record): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + + def testRestoreInModifiedGraphThenInit(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + expected_records.append(sess.run(get_next)) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that calling the init_op overrides the restored iterator. The + # iterator for the old graph was build to read uncompressed files and + # would fail when trying to read the new files. + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + test_filenames = self._createFiles( + num_files, lines_per_file, crlf=True, compression_type="GZIP") + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type="GZIP") + self._restore(saver, sess) + sess.run(init_op) + for _ in range(total_records): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + + def testDoNotRestoreIterator(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + expected_records.append(sess.run(get_next)) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + init_op, get_next, saver = self._build_graph( + test_filenames, build_saveable=False) + self._restore(saver, sess) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(get_next) + sess.run(init_op) + for _ in range(total_records): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + class FixedLengthRecordReaderTest(test.TestCase): @@ -276,18 +550,31 @@ class FixedLengthRecordReaderTest(test.TestCase): def _iterator_checkpoint_path(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_path(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_iterator_graph(self, num_epochs): filenames = self._createFiles() - path = self._iterator_checkpoint_path() dataset = (readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes) .repeat(num_epochs)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op def _restore_iterator(self): @@ -295,8 +582,7 @@ class FixedLengthRecordReaderTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) get_next = iterator.get_next() - restore_op = gen_dataset_ops.restore_iterator( - iterator._iterator_resource, self._iterator_checkpoint_path()) + restore_op = self._restore_op(iterator._iterator_resource) return restore_op, get_next def testSaveRestore(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index a19c91707541f83003ef9af1cc2a8527a8e5f6c3..0ac8d7359f7234d98167277724780bf31555e6fb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -22,11 +22,8 @@ import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.ops import string_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import device_setter from tensorflow.python.util import compat @@ -51,10 +48,8 @@ class ResampleTest(test.TestCase): seed=27)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - variable_init_op = variables.local_variables_initializer() with self.test_session() as sess: - sess.run(variable_init_op) sess.run(init_op) returned = [] with self.assertRaises(errors.OutOfRangeError): @@ -75,23 +70,6 @@ class ResampleTest(test.TestCase): returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) - def testVariableDevicePlacement(self): - classes = np.random.randint(5, size=(20000,)) # Uniformly sampled - target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] - with ops.device( - device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")): - _ = (dataset_ops.Dataset.from_tensor_slices(classes).shuffle( - 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).apply( - resampling.rejection_resample( - target_dist=target_dist, - initial_dist=None, - class_func=lambda c, _: c, - seed=27))) - - self.assertEqual(1, len(variables.local_variables())) - self.assertEqual(b"", - compat.as_bytes(variables.local_variables()[0].device)) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5338ec56bf275e481a984964e39aa0c1ade3a752 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ScanDatasetTest(test.TestCase): + + def _count(self, start, step): + return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( + scan_ops.scan(start, lambda state, _: (state + step, state))) + + def testCount(self): + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._count(start, step).take(take).make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testFibonacci(self): + iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) + ).make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(1, sess.run(next_element)) + self.assertEqual(1, sess.run(next_element)) + self.assertEqual(2, sess.run(next_element)) + self.assertEqual(3, sess.run(next_element)) + self.assertEqual(5, sess.run(next_element)) + self.assertEqual(8, sess.run(next_element)) + + def testChangingStateShape(self): + # Test the fixed-point shape invariant calculations: start with + # initial values with known shapes, and use a scan function that + # changes the size of the state on each element. + def _scan_fn(state, input_value): + # Statically known rank, but dynamic length. + ret_longer_vector = array_ops.concat([state[0], state[0]], 0) + # Statically unknown rank. + ret_larger_rank = array_ops.expand_dims(state[1], 0) + return (ret_longer_vector, ret_larger_rank), (state, input_value) + + dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply( + scan_ops.scan(([0], 1), _scan_fn)) + self.assertEqual([None], dataset.output_shapes[0][0].as_list()) + self.assertIs(None, dataset.output_shapes[0][1].ndims) + self.assertEqual([], dataset.output_shapes[1].as_list()) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(5): + (longer_vector_val, larger_rank_val), _ = sess.run(next_element) + self.assertAllEqual([0] * (2**i), longer_vector_val) + self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testIncorrectStateType(self): + + def _scan_fn(state, _): + return constant_op.constant(1, dtype=dtypes.int64), state + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + def testIncorrectReturnType(self): + + def _scan_fn(unused_state, unused_input_value): + return constant_op.constant(1, dtype=dtypes.int64) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The scan function must return a pair comprising the new state and the " + "output value."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 690cccbea389fee5acf8de8d904c5185380a6833..e0730488a1132dde2bdc67cb4ddb8d8abc0f8265 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -12,14 +12,25 @@ py_library( srcs_version = "PY2AND3", deps = [ ":transformation_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:script_ops", - "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", ], ) +py_library( + name = "iterator_ops", + srcs = [ + "iterator_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], +) + py_library( name = "readers", srcs = [ @@ -35,6 +46,7 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:nest", @@ -48,8 +60,9 @@ py_library( "enumerate_ops.py", "error_ops.py", "grouping.py", + "interleave_ops.py", "resampling.py", - "sloppy_ops.py", + "scan_ops.py", ], srcs_version = "PY2AND3", deps = [ @@ -62,7 +75,6 @@ py_library( "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", - "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 16f01557a2b55cf8061b6ebb4c7ba4afbfc0f58b..abc9212a87550745490b974d25a929a66287f785 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -68,7 +68,7 @@ def dense_to_sparse_batch(batch_size, row_shape): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -87,7 +87,7 @@ def unbatch(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -106,7 +106,7 @@ def unbatch(): def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). - Like @{tf.contrib.data.Dataset.batch}, this transformation combines + Like @{tf.data.Dataset.batch}, this transformation combines consecutive elements of this dataset into batches. However, if the batch size does not evenly divide the input dataset size, this transformation will drop the final smaller element. @@ -115,7 +115,7 @@ def batch_and_drop_remainder(batch_size): transformation and `Dataset.batch()`: ```python - dataset = tf.contrib.data.Dataset.range(200) + dataset = tf.data.Dataset.range(200) batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128)) print(batched.output_shapes) # ==> "(128,)" (the batch dimension is known) ``` @@ -130,7 +130,7 @@ def batch_and_drop_remainder(batch_size): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply} + @{tf.data.Dataset.apply} """ def _apply_fn(dataset): @@ -272,3 +272,79 @@ class _RestructuredDataset(dataset_ops.Dataset): @property def output_shapes(self): return self._output_shapes + + +class _MapAndBatchDataset(dataset_ops.MapDataset): + """A `Dataset` that maps a function over a batch of elements.""" + + def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): + """See `Dataset.map()` for details.""" + super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) + + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + self._num_parallel_batches = ops.convert_to_tensor( + num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.map_and_batch_dataset( + input_resource, + self._map_func.captured_inputs, + f=self._map_func, + batch_size=self._batch_size, + num_parallel_batches=self._num_parallel_batches, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + # pylint: enable=protected-access + + @property + def output_shapes(self): + return nest.pack_sequence_as(self._output_shapes, [ + tensor_shape.vector(tensor_util.constant_value( + self._batch_size)).concatenate(s) + for s in nest.flatten(self._output_shapes) + ]) + + @property + def output_types(self): + return self._output_types + + +def map_and_batch(map_func, batch_size, num_parallel_batches=1): + """Fused implementation of `map` and `batch`. + + Maps `map_func` across `batch_size` consecutive elements of this dataset + and then combines them into a batch. Similarly to `batch_and_drop_remainder`, + if the batch size does not evenly divide the input dataset size, this + transformation will drop the final smaller element. + + + Functionally, it is equivalent to `map` followed by + `batch_and_drop_remainder`. However, by fusing the two transformations + together, the implementation can be more efficient. This transformation is a + stop gap solution for performance critical workloads. Once automatic input + pipeline optimization are implemented, the fusing of map and batch will not + need to be exposed at the API level and this method will be removed. + + Args: + map_func: A function mapping a nested structure of tensors to another + nested structure of tensors. + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the + number of batches to create in parallel. On one hand, higher values can + help mitigate the effect of stragglers. On the other hand, higher values + can increasing contention if CPU is scarce. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _MapAndBatchDataset(dataset, map_func, batch_size, + num_parallel_batches) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index b74dcd3be2bf9a49340986bce47b0ad8c74ecc06..45d6dbe7438957029b4d6b71e181cb1fc3596ecb 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -24,10 +24,8 @@ from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops -from tensorflow.python.ops import script_ops from tensorflow.python.util import deprecation @@ -138,124 +136,8 @@ class Dataset(dataset_ops.Dataset): Returns: A `Dataset`. """ - if not callable(generator): - raise TypeError("`generator` must be callable.") - if output_shapes is None: - output_shapes = nest.map_structure( - lambda _: tensor_shape.TensorShape(None), output_types) - else: - output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - - flattened_types = nest.flatten(output_types) - flattened_shapes = nest.flatten(output_shapes) - - generator_state = dataset_ops.Dataset._GeneratorState(generator) - - def get_iterator_id_map_fn(unused_dummy): - """Creates a unique `iterator_id` for each pass over the dataset. - - The "iterator_id" disambiguates between multiple concurrently - existing iterators. - - Args: - unused_dummy: Ignored value. - - Returns: - A `tf.int64` tensor whose value uniquely identifies an iterator in - `generator_state`. - """ - return script_ops.py_func( - generator_state.get_next_id, [], dtypes.int64, stateful=True) - - def generator_map_fn(iterator_id_t): - """Generates the next element from iterator with ID `iterator_id_t`. - - We map this function across an infinite repetition of the - `iterator_id_t`, and raise `StopIteration` to terminate the iteration. - - Args: - iterator_id_t: A `tf.int64` tensor whose value uniquely identifies - the iterator in `generator_state` from which to generate an element. - - Returns: - A nested structure of tensors representing an element from the iterator. - """ - - def generator_py_func(iterator_id): - """A `py_func` that will be called to invoke the iterator.""" - try: - values = next(generator_state.get_iterator(iterator_id)) - except StopIteration: - generator_state.iterator_completed(iterator_id) - raise StopIteration("Iteration finished.") - - # Use the same _convert function from the py_func() implementation to - # convert the returned values to arrays early, so that we can inspect - # their values. - # pylint: disable=protected-access - ret_arrays = [ - script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype) - for ret, dtype in zip(nest.flatten_up_to(output_types, values), - flattened_types) - ] - # pylint: enable=protected-access - - # Additional type and shape checking to ensure that the components - # of the generated element match the `output_types` and `output_shapes` - # arguments. - for (ret_array, expected_dtype, expected_shape) in zip( - ret_arrays, flattened_types, flattened_shapes): - if ret_array.dtype != expected_dtype.as_numpy_dtype: - raise TypeError( - "`generator` yielded an element of type %s where an element " - "of type %s was expected." % (ret_array.dtype, - expected_dtype.as_numpy_dtype)) - if not expected_shape.is_compatible_with(ret_array.shape): - raise ValueError( - "`generator` yielded an element of shape %s where an element " - "of shape %s was expected." % (ret_array.shape, expected_shape)) - - return ret_arrays - - flat_values = script_ops.py_func( - generator_py_func, [iterator_id_t], flattened_types, stateful=True) - - # The `py_func()` op drops the inferred shapes, so we add them back in - # here. - if output_shapes is not None: - for ret_t, shape in zip(flat_values, flattened_shapes): - ret_t.set_shape(shape) - - return nest.pack_sequence_as(output_types, flat_values) - - # This function associates each traversal of `generator` with a unique - # iterator ID. - def flat_map_fn(iterator_id_t): - # First, generate an infinite dataset containing the iterator ID repeated - # forever. - repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None) - - # The `generator_map_fn` gets the next element from the iterator with the - # relevant ID, and raises StopIteration when that iterator contains no - # more elements. - return repeated_id.map(generator_map_fn) - - # A single-element dataset that, each time it is evaluated, contains a - # freshly-generated and unique (for the returned dataset) int64 - # ID that will be used to identify the appropriate Python state, which - # is encapsulated in `generator_state`, and captured in - # `get_iterator_id_map_fn`. - dummy = 0 - id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn) - - # A dataset that contains all of the elements generated by a - # single iterator created from `generator`, identified by the - # iterator ID contained in `id_dataset`. Lifting the iteration - # into a flat_map here enables multiple repetitions and/or nested - # versions of the returned dataset to be created, because it forces - # the generation of a new ID for each version. - return id_dataset.flat_map(flat_map_fn) + return Dataset(dataset_ops.Dataset.from_generator( + generator, output_types, output_shapes)) @staticmethod @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") @@ -760,3 +642,48 @@ class Dataset(dataset_ops.Dataset): if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`transformation_func` must return a Dataset.") return Dataset(dataset) + + +def get_single_element(dataset): + """Returns the single element in `dataset` as a nested structure of tensors. + + This function enables you to use a @{tf.data.Dataset} in a stateless + "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This can be useful when your preprocessing transformations are expressed + as a `Dataset`, and you want to use the transformation at serving time. + For example: + + ```python + input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) + + def preprocessing_fn(input_str): + # ... + return image, label + + dataset = (tf.data.Dataset.from_tensor_slices(input_batch) + .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) + .batch(BATCH_SIZE)) + + image_batch, label_batch = tf.contrib.data.get_single_element(dataset) + ``` + + Args: + dataset: A @{tf.data.Dataset} object containing a single element. + + Returns: + A nested structure of @{tf.Tensor} objects, corresponding to the single + element of `dataset`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + InvalidArgumentError (at runtime): if `dataset` does not contain exactly + one element. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + return nest.pack_sequence_as( + dataset.output_types, + gen_dataset_ops.dataset_to_single_element( + dataset._as_variant_tensor(), # pylint: disable=protected-access + output_types=nest.flatten(dataset.output_types), + output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py index 40e7315f1f6f1fa67874589cdc7ae96382d80e28..ac2b386b81532b801139baa00fd5edd4ecd6ef0a 100644 --- a/tensorflow/contrib/data/python/ops/enumerate_ops.py +++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py @@ -47,7 +47,7 @@ def enumerate_dataset(start=0): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index dffa8b7f7dc66635948a5ba52868ffea121ee164..238bb52b0205f9ab66f479f1b92e72ab6e38725b 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -30,7 +30,7 @@ def ignore_errors(): example: ```python - dataset = tf.contrib.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) + dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError. dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error")) @@ -42,7 +42,7 @@ def ignore_errors(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 2cf7e8f4ee5fc820d6f8321695abfef53d88498e..6df7b22fb69bb14c41a26bd630a825442f67ee23 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -57,7 +57,7 @@ def group_by_window(key_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. Raises: ValueError: if neither or both of {`window_size`, `window_size_func`} are diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py similarity index 62% rename from tensorflow/contrib/data/python/ops/sloppy_ops.py rename to tensorflow/contrib/data/python/ops/interleave_ops.py index 01e234f1d0db27277e9a38e6a259b4b064b89eaa..74a919c1fff62cfa79b0877a3d081077ca6776f0 100644 --- a/tensorflow/contrib/data/python/ops/sloppy_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -23,14 +23,16 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import deprecation -class SloppyInterleaveDataset(dataset_ops.Dataset): +class ParallelInterleaveDataset(dataset_ops.Dataset): """A `Dataset` that maps a function over its input and flattens the result.""" - def __init__(self, input_dataset, map_func, cycle_length, block_length): - """See `tf.contrib.data.sloppy_interleave()` for details.""" - super(SloppyInterleaveDataset, self).__init__() + def __init__(self, input_dataset, map_func, cycle_length, block_length, + sloppy): + """See `tf.contrib.data.parallel_interleave()` for details.""" + super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset @function.Defun(*nest.flatten(input_dataset.output_types)) @@ -62,13 +64,16 @@ class SloppyInterleaveDataset(dataset_ops.Dataset): cycle_length, dtype=dtypes.int64, name="cycle_length") self._block_length = ops.convert_to_tensor( block_length, dtype=dtypes.int64, name="block_length") + self._sloppy = ops.convert_to_tensor( + sloppy, dtype=dtypes.bool, name="sloppy") def _as_variant_tensor(self): - return gen_dataset_ops.sloppy_interleave_dataset( + return gen_dataset_ops.parallel_interleave_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._map_func.captured_inputs, self._cycle_length, self._block_length, + self._sloppy, f=self._map_func, output_types=nest.flatten(self.output_types), output_shapes=nest.flatten(self.output_shapes)) @@ -82,6 +87,53 @@ class SloppyInterleaveDataset(dataset_ops.Dataset): return self._output_types +def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): + """A parallel version of the `Dataset.interleave()` transformation. + + `parallel_interleave()` maps `map_func` across its input to produce nested + datasets, and outputs their elements interleaved. Unlike + @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested + datasets in parallel, which increases the throughput, especially in the + presence of stragglers. Furthermore, the `sloppy` argument can be used to + improve performance, by relaxing the requirement that the outputs are produced + in a deterministic order, and allowing the implementation to skip over nested + datasets whose elements are not readily available when requested. + + Example usage: + + ```python + # Preprocess 4 files concurrently. + filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") + dataset = filenames.apply( + tf.contrib.data.parallel_interleave( + lambda filename: tf.data.TFRecordDataset(filename), + cycle_length=4)) + ``` + + WARNING: If `sloppy` is `True`, the order of produced elements is not + deterministic. + + Args: + map_func: A function mapping a nested structure of tensors to a `Dataset`. + cycle_length: The number of threads to interleave from in parallel. + block_length: The number of consecutive elements to pull from a thread + before advancing to the next thread. + sloppy: If false, elements are produced in deterministic order. Otherwise, + the implementation is allowed, for the sake of expediency, to produce + elements in a non-deterministic order. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return ParallelInterleaveDataset( + dataset, map_func, cycle_length, block_length, sloppy) + return _apply_fn + + +@deprecation.deprecated( + None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.") def sloppy_interleave(map_func, cycle_length, block_length=1): """A non-deterministic version of the `Dataset.interleave()` transformation. @@ -102,6 +154,17 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): strictly obeys), producing an element from a different underlying dataset instead. + Example usage: + + ```python + # Preprocess 4 files concurrently. + filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") + dataset = filenames.apply( + tf.contrib.data.sloppy_interleave( + lambda filename: tf.data.TFRecordDataset(filename), + cycle_length=4)) + ``` + WARNING: The order of elements in the resulting dataset is not deterministic. Use `Dataset.interleave()` if you want the elements to have a deterministic order. @@ -118,9 +181,9 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return SloppyInterleaveDataset( - dataset, map_func, cycle_length, block_length) + return ParallelInterleaveDataset( + dataset, map_func, cycle_length, block_length, sloppy=True) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d736029fb035e573b70e8b19570e4e8ceca3c005 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Iterator ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.training import saver + + +def make_saveable_from_iterator(iterator): + """Returns a SaveableObject for saving/restore iterator state using Saver. + + Args: + iterator: Iterator. + + For example: + + ```python + with tf.Graph().as_default(): + ds = tf.data.Dataset.range(10) + iterator = ds.make_initializable_iterator() + # Build the iterator SaveableObject. + saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator) + # Add the SaveableObject to the SAVEABLE_OBJECTS collection so + # it can be automatically saved using Saver. + tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = tf.train.Saver() + + while continue_training: + ... Perform training ... + if should_save_checkpoint: + saver.save() + ``` + + Note: When restoring the iterator, the existing iterator state is completely + discarded. This means that any changes you may have made to the Dataset + graph will be discarded as well! This includes the new Dataset graph + that you may have built during validation. So, while running validation, + make sure to run the initializer for the validation input pipeline after + restoring the checkpoint. + + Note: Not all iterators support checkpointing yet. Attempting to save the + state of an unsupported iterator will throw an error. + """ + return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access + + +class _Saveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" + + def __init__(self, iterator_resource): + serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) + specs = [ + saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") + ] + super(_Saveable, self).__init__(iterator_resource, specs, + iterator_resource.name) + + def restore(self, restored_tensors, unused_restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index f4f2d4285443a5b2b5e02b3126f2edd9bd47a937..56f526a330bfbea7305b0754bfd114c5e97db506 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +29,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import resource_variable_ops def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): @@ -48,7 +48,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -68,26 +68,20 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): num_classes = (target_dist_t.shape[0].value or array_ops.shape(target_dist_t)[0]) smoothing_constant = 10 - # Disable device functions and colocation constraints so that the variable - # will be placed with the eventual DT_VARIANT dataset tensor. - with ops.colocate_with(None, ignore_existing=True): - num_examples_per_class_seen = resource_variable_ops.ResourceVariable( - initial_value=array_ops.fill([num_classes], - np.int64(smoothing_constant)), - trainable=False, - collections=[ops.GraphKeys.LOCAL_VARIABLES], - name="local_class_count", - dtype=dtypes.int64) - - def update_estimate_and_tile(c): - return array_ops.tile( - array_ops.expand_dims( - _estimate_data_distribution(c, num_examples_per_class_seen), 0), - [dist_estimation_batch_size, 1]) + initial_examples_per_class_seen = array_ops.fill( + [num_classes], np.int64(smoothing_constant)) + + def update_estimate_and_tile(num_examples_per_class_seen, c): + updated_examples_per_class_seen, dist = _estimate_data_distribution( + c, num_examples_per_class_seen) + tiled_dist = array_ops.tile( + array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) + return updated_examples_per_class_seen, tiled_dist initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .map(update_estimate_and_tile).apply(batching - .unbatch())) + .apply(scan_ops.scan(initial_examples_per_class_seen, + update_estimate_and_tile)) + .apply(batching.unbatch())) acceptance_dist_ds = initial_dist_ds.map( lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) @@ -174,20 +168,21 @@ def _estimate_data_distribution(c, num_examples_per_class_seen): Args: c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: A `ResourceVariable` containing counts. - Type `int64`, shape `[num_classes]`. + num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, + containing counts. Returns: + num_examples_per_lass_seen: Updated counts. Type `int64`, shape + `[num_classes]`. dist: The updated distribution. Type `float32`, shape `[num_classes]`. """ num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in - # batch. But do this asynchronously to avoid performing a - # cross-device round-trip. Just use the cached value. - num_examples_per_class_seen = num_examples_per_class_seen.assign_add( - math_ops.reduce_sum( + # Update the class-count based on what labels are seen in batch. + num_examples_per_class_seen = math_ops.add( + num_examples_per_class_seen, math_ops.reduce_sum( array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) init_prob_estimate = math_ops.truediv( num_examples_per_class_seen, math_ops.reduce_sum(num_examples_per_class_seen)) - return math_ops.cast(init_prob_estimate, dtypes.float32) + dist = math_ops.cast(init_prob_estimate, dtypes.float32) + return num_examples_per_class_seen, dist diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5acaed48a3d73e93706bdd0b5b2d614b0c565ab7 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -0,0 +1,182 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Scan dataset transformation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class _ScanDataset(dataset_ops.Dataset): + """A dataset that scans a function across its input.""" + + def __init__(self, input_dataset, initial_state, scan_func): + """See `scan()` for details.""" + super(_ScanDataset, self).__init__() + self._input_dataset = input_dataset + + with ops.name_scope("initial_state"): + self._initial_state = nest.pack_sequence_as(initial_state, [ + ops.convert_to_tensor(t, name="component_%d" % i) + for i, t in enumerate(nest.flatten(initial_state)) + ]) + + # Compute initial values for the state shapes and types based on + # the initial state. These will be refined by running + # `tf_scan_func` one or more times below. + self._state_shapes = nest.pack_sequence_as( + self._initial_state, + [t.shape for t in nest.flatten(self._initial_state)]) + self._state_types = nest.pack_sequence_as( + self._initial_state, + [t.dtype for t in nest.flatten(self._initial_state)]) + + # Will be populated by calling `tf_scan_func`. + self._output_shapes = None + self._output_types = None + + # Iteratively rerun the scan function until reaching a fixed pont on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + flat_state_shapes = nest.flatten(self._state_shapes) + flat_state_types = nest.flatten(self._state_types) + + # Create a list in which `tf_scan_func` will store the s + flat_new_state_shapes = [] + + @function.Defun( + *(flat_state_types + nest.flatten(input_dataset.output_types))) + def tf_scan_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the state and input_dataset. + for arg, shape in zip( + args, + flat_state_shapes + nest.flatten(input_dataset.output_shapes)): + arg.set_shape(shape) + + pivot = len(flat_state_shapes) + old_state = nest.pack_sequence_as(self._initial_state, args[:pivot]) + input_value = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + + ret = scan_func(old_state, input_value) + if not isinstance(ret, collections.Sequence) or len(ret) != 2: + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + new_state, output_value = ret + + flat_new_state = [ + ops.convert_to_tensor(t) for t in nest.flatten(new_state) + ] + flat_output_value = [ + ops.convert_to_tensor(t) for t in nest.flatten(output_value) + ] + + # Extract shape information from the returned values. + flat_new_state_shapes.extend([t.shape for t in flat_new_state]) + self._output_shapes = nest.pack_sequence_as( + output_value, [t.shape for t in flat_output_value]) + + # Extract and validate type information from the returned values. + for t, dtype in zip(flat_new_state, flat_state_types): + if t.dtype != dtype: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, nest.pack_sequence_as( + self._state_types, [t.dtype for t in flat_new_state]))) + self._output_types = nest.pack_sequence_as( + output_value, [t.dtype for t in flat_output_value]) + + return flat_new_state + flat_output_value + + # Use the private method that will execute `tf_scan_func` but delay + # adding it to the graph in case we need to rerun the function. + tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + + weakened_state_shapes = [ + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( + weakened_shape.ndims is None or + original_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun + # `tf_scan_func`. + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._scan_func = tf_scan_func + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.scan_dataset( + input_t, + nest.flatten(self._initial_state), + self._scan_func.captured_inputs, + f=self._scan_func, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +def scan(initial_state, scan_func): + """A transformation that scans a function across an input dataset. + + This transformation is a stateful relative of @{tf.data.Dataset.map}. + In addition to mapping `scan_func` across the elements of the input dataset, + `scan()` accumulates one or more state tensors, whose initial values are + `initial_state`. + + Args: + initial_state: A nested structure of tensors, representing the initial state + of the accumulator. + scan_func: A function that maps `(old_state, input_element)` to + `(new_state, output_element). It must take two arguments and return a + pair of nested structures of tensors. The `new_state` must match the + structure of `initial_state`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _ScanDataset(dataset, initial_state, scan_func) + + return _apply_fn diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index aef73f05983e143448807679a461a0aa09fc0a59..4a4f3789016bed5db475da81b2448b682f158353 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -18,14 +18,20 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", + "//tensorflow/python:template", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python/ops/distributions", + "//tensorflow/python/ops/linalg", "//third_party/py/numpy", ], ) @@ -55,7 +61,9 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", + "//tensorflow/python/ops/linalg", "//third_party/py/numpy", "@six_archive//:six", ], @@ -305,6 +313,8 @@ cuda_py_test( additional_deps = [ ":distributions_py", "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:client_testlib", ], ) @@ -795,6 +805,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "gumbel_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/gumbel_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "inline_test", size = "small", @@ -833,6 +862,38 @@ cuda_py_test( ], ) +cuda_py_test( + name = "masked_autoregressive_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/masked_autoregressive_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "permute_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/permute_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "power_transform_test", size = "small", @@ -852,6 +913,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "reshape_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/reshape_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index f33cc1de0abc82a3a8974dba4459a55fb4c2e82c..16f6533e57347a5fe41b017c9855d216fba9da82 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -28,8 +28,11 @@ from tensorflow.contrib.distributions.python.ops.chi2 import * from tensorflow.contrib.distributions.python.ops.conditional_distribution import * from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform +from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse +from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag from tensorflow.contrib.distributions.python.ops.estimator import * from tensorflow.contrib.distributions.python.ops.geometric import * from tensorflow.contrib.distributions.python.ops.independent import * @@ -140,13 +143,14 @@ _allowed_symbols = [ 'RelaxedOneHotCategorical', 'kl_divergence', 'RegisterKL', - 'matrix_diag_transform', 'fill_triangular', + 'matrix_diag_transform', + 'reduce_weighted_logsumexp', + 'softplus_inverse', + 'tridiag', 'normal_conjugates_known_scale_posterior', 'normal_conjugates_known_scale_predictive', - 'softplus_inverse', 'percentile', - 'reduce_weighted_logsumexp', 'assign_moving_mean_variance', 'assign_log_moving_mean_exp', 'moving_mean_variance', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py index da50037d6e128dc9a6c1214a2f780a0bfee112c7..e0d65c79b2654c2949de161d6317f218d11cab43 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py @@ -68,6 +68,18 @@ class AbsoluteValueTest(test.TestCase): sess.run(abs_bijector.inverse_log_det_jacobian([1.]), feed_dict={event_ndims: 1}) + def testNegativeYRaisesForInverseIfValidateArgs(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + with self.assertRaisesOpError("y was negative"): + sess.run(bijector.inverse(-1.)) + + def testNegativeYRaisesForILDJIfValidateArgs(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + with self.assertRaisesOpError("y was negative"): + sess.run(bijector.inverse_log_det_jacobian(-1.)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index 0738754b217e5842bd0fa516915f14926083d321..405ddd292cacd8ace87d6caeebf3e8cfc347c22d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -72,7 +72,7 @@ class AffineLinearOperatorTest(test.TestCase): [3, -2, 0], [4, 3, 2]]], dtype=np.float32) - scale = linalg.LinearOperatorTriL(tril, is_non_singular=True) + scale = linalg.LinearOperatorLowerTriangular(tril, is_non_singular=True) affine = AffineLinearOperator( shift=shift, scale=scale, validate_args=True) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9a905980c7581a86bbcda8c6c726da57c09fe4f8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py @@ -0,0 +1,70 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats + +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import Gumbel +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class GumbelBijectorTest(test.TestCase): + """Tests correctness of the Gumbel bijector.""" + + def testBijector(self): + with self.test_session(): + loc = 0.3 + scale = 5. + bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True) + self.assertEqual("gumbel", bijector.name) + x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32) + # Gumbel distribution + gumbel_dist = stats.gumbel_r(loc=loc, scale=scale) + y = gumbel_dist.cdf(x).astype(np.float32) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + # We should lose a dimension from calculating the determinant of the + # jacobian. + np.squeeze(gumbel_dist.logpdf(x), axis=2), + bijector.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(y).eval(), + bijector.forward_log_det_jacobian(x).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Gumbel(loc=0.3, scale=20.), lower_x=1., upper_x=100., rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True) + x = np.linspace(-10., 10., num=10).astype(np.float32) + y = np.linspace(0.01, 0.99, num=10).astype(np.float32) + assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py new file mode 100644 index 0000000000000000000000000000000000000000..25a9b6f5fe2ed6d218d6b44650fce17fa89c0664 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MaskedAutoregressiveFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import masked_autoregressive_default_template +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import _gen_mask +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.platform import test + + +class GenMaskTest(test.TestCase): + + def test346Exclusive(self): + expected_mask = np.array( + [[0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 0, 0]]) + mask = _gen_mask(num_blocks=3, n_in=4, n_out=6, mask_type="exclusive") + self.assertAllEqual(expected_mask, mask) + + def test346Inclusive(self): + expected_mask = np.array( + [[1, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 0]]) + mask = _gen_mask(num_blocks=3, n_in=4, n_out=6, mask_type="inclusive") + self.assertAllEqual(expected_mask, mask) + + +class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, + test.TestCase): + + @property + def _autoregressive_flow_kwargs(self): + return { + "shift_and_log_scale_fn": masked_autoregressive_default_template( + hidden_layers=[2], shift_only=False), + "is_constant_jacobian": False, + } + + def testBijector(self): + x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4, 2) + with self.test_session() as sess: + ma = MaskedAutoregressiveFlow( + validate_args=True, + **self._autoregressive_flow_kwargs) + x = constant_op.constant(x_) + forward_x = ma.forward(x) + # Use identity to invalidate cache. + inverse_y = ma.inverse(array_ops.identity(forward_x)) + fldj = ma.forward_log_det_jacobian(x) + # Use identity to invalidate cache. + ildj = ma.inverse_log_det_jacobian(array_ops.identity(forward_x)) + variables.global_variables_initializer().run() + [ + forward_x_, + inverse_y_, + ildj_, + fldj_, + ] = sess.run([ + forward_x, + inverse_y, + ildj, + fldj, + ]) + self.assertEqual("masked_autoregressive_flow", ma.name) + self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.) + self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.) + self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.) + + def testMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + ma = MaskedAutoregressiveFlow( + validate_args=True, + **self._autoregressive_flow_kwargs) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=ma, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + def testInvertMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + ma = Invert(MaskedAutoregressiveFlow( + validate_args=True, + **self._autoregressive_flow_kwargs)) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=ma, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + +class MaskedAutoregressiveFlowShiftOnlyTest(MaskedAutoregressiveFlowTest): + + @property + def _autoregressive_flow_kwargs(self): + return { + "shift_and_log_scale_fn": masked_autoregressive_default_template( + hidden_layers=[2], shift_only=True), + "is_constant_jacobian": True, + } + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py new file mode 100644 index 0000000000000000000000000000000000000000..54590de373441c32cc3214cb04d45cfc2d1807ed --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Permute bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.permute import Permute +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + +class PermuteBijectorTest(test.TestCase): + """Tests correctness of the Permute bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testBijector(self): + expected_permutation = np.int32([2, 0, 1]) + expected_x = np.random.randn(4, 2, 3) + expected_y = expected_x[..., expected_permutation] + + with self.test_session() as sess: + permutation_ph = array_ops.placeholder(dtype=dtypes.int32) + bijector = Permute( + permutation=permutation_ph, + validate_args=True) + [ + permutation_, + x_, + y_, + fldj, + ildj, + ] = sess.run([ + bijector.permutation, + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.forward_log_det_jacobian(expected_x), + bijector.inverse_log_det_jacobian(expected_y), + ], feed_dict={permutation_ph: expected_permutation}) + self.assertEqual("permute", bijector.name) + self.assertAllEqual(expected_permutation, permutation_) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(0., fldj, rtol=1e-6, atol=0) + self.assertAllClose(0., ildj, rtol=1e-6, atol=0) + + def testRaisesOpError(self): + with self.test_session() as sess: + with self.assertRaisesOpError("Permutation over `d` must contain"): + permutation_ph = array_ops.placeholder(dtype=dtypes.int32) + bijector = Permute( + permutation=permutation_ph, + validate_args=True) + sess.run(bijector.inverse([1.]), + feed_dict={permutation_ph: [1, 2]}) + + def testBijectiveAndFinite(self): + permutation = np.int32([2, 0, 1]) + x = np.random.randn(4, 2, 3) + y = x[..., permutation] + with self.test_session(): + bijector = Permute( + permutation=permutation, + validate_args=True) + assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py new file mode 100644 index 0000000000000000000000000000000000000000..38b3a23c2d684a6f89b7c4be4a763c649bf4de15 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -0,0 +1,242 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Reshape Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + +class ReshapeBijectorTest(test.TestCase): + """Tests correctness of the reshape transformation.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testBijector(self): + """Do a basic sanity check of forward, inverse, jacobian.""" + expected_x = np.random.randn(4, 3, 2) + expected_y = np.reshape(expected_x, [4, 6]) + + with self.test_session() as sess: + bijector = Reshape( + event_shape_out=[6,], + event_shape_in=[3, 2], + validate_args=True) + (x_, + y_, + fldj_, + ildj_) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.forward_log_det_jacobian(expected_x), + bijector.inverse_log_det_jacobian(expected_y), + )) + self.assertEqual("reshape", bijector.name) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) + self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) + + def testEventShapeDynamicNdims(self): + """Check forward/inverse shape methods with dynamic ndims.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) + + shape_out = tensor_shape.TensorShape([2, 3]) + shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) + + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, validate_args=True) + + # using the _tensor methods, we should always get a fully-specified + # result since these are evaluated at graph runtime. + with self.test_session() as sess: + (shape_out_, + shape_in_) = sess.run(( + bijector.forward_event_shape_tensor(shape_in), + bijector.inverse_event_shape_tensor(shape_out), + ), feed_dict={ + shape_in_ph: shape_in, + shape_out_ph: shape_out, + }) + self.assertAllEqual(shape_out, shape_out_) + self.assertAllEqual(shape_in, shape_in_) + + def testEventShapeDynamic(self): + """Check shape methods with static ndims but dynamic shape.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_in_partial = tensor_shape.TensorShape([None,]) + shape_in_ph = array_ops.placeholder( + shape=[1,], dtype=dtypes.int32) + + shape_out = tensor_shape.TensorShape([2, 3]) + shape_out_partial = tensor_shape.TensorShape([None, None]) + shape_out_ph = array_ops.placeholder( + shape=[2,], dtype=dtypes.int32) + + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, + validate_args=True) + + # if event shapes are not statically available, should + # return partially-specified TensorShapes. + self.assertAllEqual( + bijector.forward_event_shape(shape_in).as_list(), + shape_out_partial.as_list()) + self.assertAllEqual( + bijector.inverse_event_shape(shape_out).as_list(), + shape_in_partial.as_list()) + + # using the _tensor methods, we should always get a fully-specified + # result since these are evaluated at graph runtime. + with self.test_session() as sess: + (shape_out_, + shape_in_) = sess.run(( + bijector.forward_event_shape_tensor(shape_in), + bijector.inverse_event_shape_tensor(shape_out), + ), feed_dict={ + shape_in_ph: shape_in, + shape_out_ph: shape_out, + }) + self.assertAllEqual(shape_out, shape_out_) + self.assertAllEqual(shape_in, shape_in_) + + def testEventShapeStatic(self): + """Check shape methods when shape is statically known.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_out = tensor_shape.TensorShape([2, 3]) + + bijector_static = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector_static.forward_event_shape(shape_in), + shape_out) + self.assertEqual( + bijector_static.inverse_event_shape(shape_out), + shape_in) + + with self.test_session() as sess: + (shape_out_static_, + shape_in_static_, + ) = sess.run(( + bijector_static.forward_event_shape_tensor(shape_in), + bijector_static.inverse_event_shape_tensor(shape_out), + )) + self.assertAllEqual(shape_out, shape_out_static_) + self.assertAllEqual(shape_in, shape_in_static_) + + def testScalarReshape(self): + """Test reshaping to and from a scalar shape ().""" + + expected_x = np.random.randn(4, 3, 1) + expected_y = np.reshape(expected_x, [4, 3]) + + expected_x_scalar = np.random.randn(1,) + expected_y_scalar = expected_x_scalar[0] + + with self.test_session() as sess: + bijector = Reshape( + event_shape_out=[], + event_shape_in=[1,], validate_args=True) + + (x_, + y_, + x_scalar_, + y_scalar_ + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.inverse(expected_y_scalar), + bijector.forward(expected_x_scalar), + )) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) + + def testRaisesOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 3, 2) + x3 = np.random.randn(4, 5, 1, 1) + + with self.test_session() as sess: + shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, + validate_args=True) + + with self.assertRaisesOpError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x2), + feed_dict={shape_out_ph: [1, 6, 1], + shape_in_ph: [2, 3]}) + + with self.assertRaisesOpError( + "event_shape_out entries must be positive."): + sess.run(bijector.forward(x1), + feed_dict={shape_out_ph: [-1, -1, 6], + shape_in_ph: [2, 3]}) + + # test that *all* methods check basic assertions + fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.forward(x1), feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.inverse_log_det_jacobian(x3), + feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.forward_log_det_jacobian(x1), + feed_dict=fd_mismatched) + + def testBijectiveAndFinite(self): + x = np.random.randn(4, 2, 3) + y = np.reshape(x, [4, 1, 2, 3]) + with self.test_session(): + bijector = Reshape( + event_shape_in=[2, 3], + event_shape_out=[1, 2, 3], + validate_args=True) + assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 230dd93a2a807cc14394e3c747c208c1f95b194d..172c180a44229089f06f250a872bc47a89991cf0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -41,7 +41,7 @@ class SinhArcsinhBijectorTest(test.TestCase): tailweight=tailweight, event_ndims=1, validate_args=True) - self.assertEqual("sinh_arcsinh", bijector.name) + self.assertEqual("SinhArcsinh", bijector.name) x = np.array([[[-2.01], [2.], [1e-4]]]).astype(np.float32) y = np.sinh((np.arcsinh(x) + skewness) * tailweight) self.assertAllClose(y, bijector.forward(x).eval()) @@ -170,6 +170,12 @@ class SinhArcsinhBijectorTest(test.TestCase): with self.assertRaisesOpError("not positive"): SinhArcsinh(tailweight=0., validate_args=True).forward(1.0).eval() + def testDefaultDtypeIsFloat32(self): + with self.test_session(): + bijector = SinhArcsinh() + self.assertEqual(bijector.tailweight.dtype, np.float32) + self.assertEqual(bijector.skewness.dtype, np.float32) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index d10312d6670b07caf455b9c09c50bb713711d77d..2d74aa1f320149d0f7ef9e9c52b8c7053c2f74d7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -23,11 +23,11 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.linalg.python.ops import linear_operator_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops.linalg import linear_operator_diag import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py index 7a321db4b296e0f1f09874043a4568e6809f10fc..8e23a3ab8fd1ea0432e2bacd1a7c945d09105bf2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py @@ -23,8 +23,10 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import independent as independent_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -41,8 +43,10 @@ def try_import(name): # pylint: disable=invalid-name stats = try_import("scipy.stats") -class ProductDistributionTest( - test_util.VectorDistributionTestHelpers, test.TestCase): +class ProductDistributionTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(42) def testSampleAndLogProbUnivariate(self): loc = np.float32([-1., 1]) @@ -50,7 +54,7 @@ class ProductDistributionTest( with self.test_session() as sess: ind = independent_lib.Independent( distribution=normal_lib.Normal(loc=loc, scale=scale), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) x = ind.sample([4, 5]) log_prob_x = ind.log_prob(x) @@ -73,7 +77,7 @@ class ProductDistributionTest( distribution=mvn_diag_lib.MultivariateNormalDiag( loc=loc, scale_identity_multiplier=scale), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) x = ind.sample([4, 5]) log_prob_x = ind.log_prob(x) @@ -98,7 +102,7 @@ class ProductDistributionTest( distribution=mvn_diag_lib.MultivariateNormalDiag( loc=loc, scale_identity_multiplier=scale), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) x = ind.sample(int(n_samp), seed=42) sample_mean = math_ops.reduce_mean(x, axis=0) @@ -122,6 +126,59 @@ class ProductDistributionTest( self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.) self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.) + def _testMnistLike(self, static_shape): + sample_shape = [4, 5] + batch_shape = [10] + image_shape = [28, 28, 1] + logits = 3 * self._rng.random_sample( + batch_shape + image_shape).astype(np.float32) - 1 + + def expected_log_prob(x, logits): + return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1) + + with self.test_session() as sess: + logits_ph = array_ops.placeholder( + dtypes.float32, shape=logits.shape if static_shape else None) + ind = independent_lib.Independent( + distribution=bernoulli_lib.Bernoulli(logits=logits_ph)) + x = ind.sample(sample_shape) + log_prob_x = ind.log_prob(x) + [ + x_, + actual_log_prob_x, + ind_batch_shape, + ind_event_shape, + x_shape, + log_prob_x_shape, + ] = sess.run([ + x, + log_prob_x, + ind.batch_shape_tensor(), + ind.event_shape_tensor(), + array_ops.shape(x), + array_ops.shape(log_prob_x), + ], feed_dict={logits_ph: logits}) + + if static_shape: + ind_batch_shape = ind.batch_shape + ind_event_shape = ind.event_shape + x_shape = x.shape + log_prob_x_shape = log_prob_x.shape + + self.assertAllEqual(batch_shape, ind_batch_shape) + self.assertAllEqual(image_shape, ind_event_shape) + self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape) + self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape) + self.assertAllClose(expected_log_prob(x_, logits), + actual_log_prob_x, + rtol=1e-6, atol=0.) + + def testMnistLikeStaticShape(self): + self._testMnistLike(static_shape=True) + + def testMnistLikeDynamicShape(self): + self._testMnistLike(static_shape=False) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index 47ac412500d36df999225d94be0ecb7cccf75723..ece6bc077d9e21502fdfd01300a9d3e9f2c9c380 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -23,94 +23,116 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import mixture_same_family as mixture_same_family_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -class MixtureSameFamilyTest( - test_util.VectorDistributionTestHelpers, test.TestCase): +class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, + test.TestCase): def testSampleAndLogProbUnivariateShapes(self): with self.test_session(): gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=normal_lib.Normal( - loc=[-1., 1], - scale=[0.1, 0.5])) - x = gm.sample([4, 5]) + loc=[-1., 1], scale=[0.1, 0.5])) + x = gm.sample([4, 5], seed=42) log_prob_x = gm.log_prob(x) self.assertEqual([4, 5], x.shape) self.assertEqual([4, 5], log_prob_x.shape) def testSampleAndLogProbShapesBroadcastMix(self): mix_probs = np.float32([.3, .7]) - bern_probs = np.float32([[.4, .6], - [.25, .75]]) + bern_probs = np.float32([[.4, .6], [.25, .75]]) with self.test_session(): bm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=mix_probs), - components_distribution=bernoulli_lib.Bernoulli( - probs=bern_probs)) - x = bm.sample([4, 5]) + mixture_distribution=categorical_lib.Categorical(probs=mix_probs), + components_distribution=bernoulli_lib.Bernoulli(probs=bern_probs)) + x = bm.sample([4, 5], seed=42) log_prob_x = bm.log_prob(x) x_ = x.eval() self.assertEqual([4, 5, 2], x.shape) self.assertEqual([4, 5, 2], log_prob_x.shape) - self.assertAllEqual(np.ones_like(x_, dtype=np.bool), - np.logical_or(x_ == 0., x_ == 1.)) + self.assertAllEqual( + np.ones_like(x_, dtype=np.bool), np.logical_or(x_ == 0., x_ == 1.)) def testSampleAndLogProbMultivariateShapes(self): with self.test_session(): gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) - x = gm.sample([4, 5]) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) + x = gm.sample([4, 5], seed=42) log_prob_x = gm.log_prob(x) self.assertEqual([4, 5, 2], x.shape) self.assertEqual([4, 5], log_prob_x.shape) + def testSampleAndLogProbBatchMultivariateShapes(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=[[[-1., 1], + [1, -1]], + [[0., 1], + [1, 0]]], + scale_identity_multiplier=[1., 0.5])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 2, 2], x.shape) + self.assertEqual([4, 5, 2], log_prob_x.shape) + def testSampleConsistentLogProb(self): with self.test_session() as sess: gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, gm, radius=1., center=[-1., 1], rtol=0.02) + sess.run, gm, radius=1., center=[-1., 1], rtol=0.02) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, gm, radius=1., center=[1., -1], rtol=0.02) + sess.run, gm, radius=1., center=[1., -1], rtol=0.02) + + def testLogCdf(self): + with self.test_session() as sess: + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=normal_lib.Normal( + loc=[-1., 1], scale=[0.1, 0.5])) + x = gm.sample(10, seed=42) + actual_log_cdf = gm.log_cdf(x) + expected_log_cdf = math_ops.reduce_logsumexp( + (gm.mixture_distribution.logits + + gm.components_distribution.log_cdf(x[..., array_ops.newaxis])), + axis=1) + actual_log_cdf_, expected_log_cdf_ = sess.run([ + actual_log_cdf, expected_log_cdf]) + self.assertAllClose(actual_log_cdf_, expected_log_cdf_, + rtol=1e-6, atol=0.0) def testSampleConsistentMeanCovariance(self): with self.test_session() as sess: gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) - self.run_test_sample_consistent_mean_covariance(sess, gm) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) + self.run_test_sample_consistent_mean_covariance(sess.run, gm) def testVarianceConsistentCovariance(self): with self.test_session() as sess: gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) cov_, var_ = sess.run([gm.covariance(), gm.variance()]) self.assertAllClose(cov_.diagonal(), var_, atol=0.) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 43e302475b49ef5245ba324c35ca294b51a566b6..933756aa8e12cca4c42eb98d9193512bbf2ad585 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -289,6 +289,18 @@ class MultivariateNormalDiagTest(test.TestCase): self.assertListEqual(mvn.batch_shape.as_list(), [2, 3]) self.assertListEqual(mvn.event_shape.as_list(), [None]) + def testKLDivIdenticalGradientDefined(self): + dims = 3 + with self.test_session() as sess: + loc = array_ops.zeros([dims], dtype=dtypes.float32) + mvn = ds.MultivariateNormalDiag( + loc=loc, + scale_diag=np.ones([dims], dtype=np.float32)) + g = gradients_impl.gradients(ds.kl_divergence(mvn, mvn), loc) + g_ = sess.run(g) + self.assertAllEqual(np.ones_like(g_, dtype=np.bool), + np.isfinite(g_)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py index c1a74c6483b9843c609ac94054a8c27476f7d7ff..37edaa42cdc202cda4aa173752a3639792f96daf 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py @@ -241,6 +241,28 @@ class NegativeBinomialTest(test.TestCase): atol=0., rtol=.02) + def testLogProbOverflow(self): + with self.test_session() as sess: + logits = np.float32([20., 30., 40.]) + total_count = np.float32(1.) + x = np.float32(0.) + nb = negative_binomial.NegativeBinomial( + total_count=total_count, logits=logits) + log_prob_ = sess.run(nb.log_prob(x)) + self.assertAllEqual(np.ones_like(log_prob_, dtype=np.bool), + np.isfinite(log_prob_)) + + def testLogProbUnderflow(self): + with self.test_session() as sess: + logits = np.float32([-90, -100, -110]) + total_count = np.float32(1.) + x = np.float32(0.) + nb = negative_binomial.NegativeBinomial( + total_count=total_count, logits=logits) + log_prob_ = sess.run(nb.log_prob(x)) + self.assertAllEqual(np.ones_like(log_prob_, dtype=np.bool), + np.isfinite(log_prob_)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py index 7cb46bb2367658518c98baaa14947b5ad837ff12..3c0147b8cf6e1b6a2791e85c0c0997992445fa7e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py @@ -18,8 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import poisson_lognormal from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -32,60 +36,80 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=-2., scale=1.1, - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1) + sess.run, pln, rtol=0.1) def testMeanVariance(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=0., scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.02) + sess.run, pln, rtol=0.02) def testSampleProbConsistentBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) def testMeanVarianceBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) def testSampleProbConsistentBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[[0.], [-0.5]], scale=[[1., 0.9]], - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1, atol=0.08) + sess.run, pln, rtol=0.1, atol=0.08) def testMeanVarianceBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[[0.], [-0.5]], scale=[[1., 0.9]], - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) + + def testSampleProbConsistentDynamicQuadrature(self): + with self.test_session() as sess: + qgrid = array_ops.placeholder(dtype=dtypes.float32) + qprobs = array_ops.placeholder(dtype=dtypes.float32) + g, p = np.polynomial.hermite.hermgauss(deg=10) + pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( + loc=-2., + scale=1.1, + quadrature_grid_and_probs=(g, p), + validate_args=True) + self.run_test_sample_consistent_log_prob( + lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}), + pln, rtol=0.1) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py index 8c8363fe3f5159ed4def82472df8cb8ff518b05c..faae9da6ad812c629a2bdbb985fdd6f78a0860e1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py @@ -164,6 +164,14 @@ class RelaxedOneHotCategoricalTest(test.TestCase): self.assertAllEqual([5, 3], dist.sample(5).eval(feed_dict=feed_dict).shape) + def testDTypes(self): + # check that sampling and log_prob work for a range of dtypes + with self.test_session(): + for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + logits = random_ops.random_uniform(shape=[3, 3], dtype=dtype) + dist = relaxed_onehot_categorical.RelaxedOneHotCategorical( + temperature=0.5, logits=logits) + dist.log_prob(dist.sample()) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py index 8ea3a592552b86495b88640df1be732f9a0b0778..88b48736dd55270fb4e149ae1560911179e446e9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py @@ -200,6 +200,22 @@ class SinhArcsinhTest(test.TestCase): sasnorm_samps = sess.run(sasnorm.sample(10000, seed=4)) np.testing.assert_array_less(loc, sasnorm_samps.mean(axis=0)) + def test_pdf_reflected_for_negative_skewness(self): + with self.test_session() as sess: + sas_pos_skew = ds.SinhArcsinh( + loc=0., + scale=1., + skewness=2., + validate_args=True) + sas_neg_skew = ds.SinhArcsinh( + loc=0., + scale=1., + skewness=-2., + validate_args=True) + x = np.linspace(-2, 2, num=5).astype(np.float32) + self.assertAllClose( + *sess.run([sas_pos_skew.prob(x), sas_neg_skew.prob(x[::-1])])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 4001530f6654a656891ebc15397cc3f618711bd8..103d8e186221e879d1734a097114708429f725bd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -116,6 +116,18 @@ class TransformedDistributionTest(test.TestCase): np.log(sp_normal.pdf(2.13) + sp_normal.pdf(-2.13)), abs_normal.log_prob(2.13).eval()) + def testQuantile(self): + with self.test_session() as sess: + logit_normal = self._cls()( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.Sigmoid(), + validate_args=True) + grid = [0., 0.25, 0.5, 0.75, 1.] + q = logit_normal.quantile(grid) + cdf = logit_normal.cdf(q) + cdf_ = sess.run(cdf) + self.assertAllClose(grid, cdf_, rtol=1e-6, atol=0.) + def testCachedSamples(self): exp_forward_only = bs.Exp(event_ndims=0) exp_forward_only._inverse = self._make_unimplemented( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index 070ee61be314905239e11e8ed3b39f6ffa7510a7..de4a221f7badca8267a81d612a57137c676ff052 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -22,9 +22,11 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib +from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.platform import test @@ -55,10 +57,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.005) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) def testSampleProbConsistentBroadcastMixNonStandardBase(self): with self.test_session() as sess: @@ -83,10 +85,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=1., rtol=0.006) + sess.run, vdm, radius=2., center=1., rtol=0.006) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=3., rtol=0.009) + sess.run, vdm, radius=4., center=3., rtol=0.009) def testSampleProbConsistentBroadcastMixBatch(self): with self.test_session() as sess: @@ -114,10 +116,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.005) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) def testMeanCovarianceNoBatch(self): with self.test_session() as sess: @@ -141,7 +143,7 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.06) def testMeanCovarianceNoBatchUncenteredNonStandardBase(self): with self.test_session() as sess: @@ -165,7 +167,7 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) + sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) def testMeanCovarianceBatch(self): with self.test_session() as sess: @@ -192,7 +194,40 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.06) + + def testSampleProbConsistentDynamicQuadrature(self): + with self.test_session() as sess: + qgrid = array_ops.placeholder(dtype=dtypes.float32) + qprobs = array_ops.placeholder(dtype=dtypes.float32) + g, p = np.polynomial.hermite.hermgauss(deg=8) + dims = 4 + vdm = vector_diffeomixture_lib.VectorDiffeomixture( + mix_loc=[[0.], [1.]], + mix_scale=[1.], + distribution=normal_lib.Normal(0., 1.), + loc=[ + None, + np.float32([2.]*dims), + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=np.float32(1.1), + is_positive_definite=True), + linop_diag_lib.LinearOperatorDiag( + diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), + is_positive_definite=True), + ], + quadrature_grid_and_probs=(g, p), + validate_args=True) + # Ball centered at component0's mean. + sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}) + self.run_test_sample_consistent_log_prob( + sess_run_fn, vdm, radius=2., center=0., rtol=0.005) + # Larger ball centered at component1's mean. + self.run_test_sample_consistent_log_prob( + sess_run_fn, vdm, radius=4., center=2., rtol=0.005) # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent, # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py index a7140cd98b4fbf2b6fc6c505a884619957e6eef1..2bc6a926dd66fd2b5796576c723345ca2014aad6 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py @@ -210,15 +210,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers, validate_args=True) self.run_test_sample_consistent_log_prob( - sess, sasnorm, radius=1.0, center=0., rtol=0.1) + sess.run, sasnorm, radius=1.0, center=0., rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=-0.15, rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=0.15, @@ -237,20 +237,36 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers, validate_args=True) self.run_test_sample_consistent_log_prob( - sess, sasnorm, radius=1.0, center=0., rtol=0.1) + sess.run, sasnorm, radius=1.0, center=0., rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=-0.15, rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=0.15, rtol=0.1) + def test_pdf_reflected_for_negative_skewness(self): + with self.test_session() as sess: + sas_pos_skew = ds.VectorSinhArcsinhDiag( + loc=[0.], + scale_identity_multiplier=1., + skewness=2., + validate_args=True) + sas_neg_skew = ds.VectorSinhArcsinhDiag( + loc=[0.], + scale_identity_multiplier=1., + skewness=-2., + validate_args=True) + x = np.linspace(-2, 2, num=5).astype(np.float32).reshape(5, 1) + self.assertAllClose( + *sess.run([sas_pos_skew.prob(x), sas_neg_skew.prob(x[::-1])])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 4541701109aa4d64890b3bfe53dad6d5830c2ac6..bc0ec7f195af009c87020ce8c4ea18f2e713759a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -22,16 +22,23 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@Gumbel @@Identity @@Inline @@Invert +@@MaskedAutoregressiveFlow +@@Permute @@PowerTransform +@@Reshape @@Sigmoid @@SigmoidCentered @@SinhArcsinh @@SoftmaxCentered @@Softplus @@Weibull + +@@masked_autoregressive_default_template +@@masked_dense """ from __future__ import absolute_import @@ -47,9 +54,13 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * +from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py index 065a049cf7ded10b56f48ed2228b00796c72d46d..b84502003ab6c0c4ffdda21eea162f441509e1fa 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py @@ -35,7 +35,17 @@ class AbsoluteValue(bijector.Bijector): """Computes `Y = g(X) = Abs(X)`, element-wise. This non-injective bijector allows for transformations of scalar distributions - with the absolute value function. + with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. + + * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse + `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. + * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse + (the set inverse is the singleton `{0}`), but "works" in conjunction with + `TransformedDistribution` to produce a left semi-continuous pdf. + * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the + wrong thing, `-y, y`. This is done for efficiency. If + `validate_args == True`, `y < 0` will raise an exception. + ```python abs = ds.bijectors.AbsoluteValue() @@ -68,7 +78,8 @@ class AbsoluteValue(bijector.Bijector): with a particular draw from the distribution. Currently only zero is supported. validate_args: Python `bool` indicating whether arguments should be - checked for correctness. + checked for correctness, in particular whether inputs to `inverse` and + `inverse_log_det_jacobian` are non-negative. name: Python `str` name given to ops managed by this object. Raises: @@ -98,6 +109,10 @@ class AbsoluteValue(bijector.Bijector): return math_ops.abs(x) def _inverse(self, y): + if self.validate_args: + y = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + y) return -y, y def _inverse_log_det_jacobian(self, y): @@ -106,6 +121,10 @@ class AbsoluteValue(bijector.Bijector): # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + if self.validate_args: + zeros = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + zeros) return zeros, zeros @property diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py index f74d699a43ecada10f9b1aef61ba3a1e8472e5dc..05bb9c2f9bdf35e222c94db3491157893da64ebd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py @@ -326,7 +326,7 @@ class Affine(bijector.Bijector): shape_hint=shape_hint) if perturb_factor is not None: - return linalg.LinearOperatorUDVHUpdate( + return linalg.LinearOperatorLowRankUpdate( scale, u=perturb_factor, diag_update=perturb_diag, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py index ae380b5cb2bc39e06aa1e187c134d7e92f6cd92f..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.contrib.linalg.python.ops import linear_operator from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -27,6 +26,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator __all__ = [ @@ -66,7 +66,7 @@ class AffineLinearOperator(bijector.Bijector): Example Use: ```python - linalg = tf.contrib.linalg + linalg = tf.linalg x = [1., 2, 3] @@ -82,7 +82,7 @@ class AffineLinearOperator(bijector.Bijector): tril = [[1., 0, 0], [2, 1, 0], [3, 2, 1]] - scale = linalg.LinearOperatorTriL(tril) + scale = linalg.LinearOperatorLowerTriangular(tril) affine = AffineLinearOperator(shift, scale) # In this case, `forward` is equivalent to: # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py new file mode 100644 index 0000000000000000000000000000000000000000..cf37aa51115ed98ab263bc03bcb297a03432a7ae --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gumbel bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Gumbel"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..67f39785563255be0fe154aca3cbcf01c6a01e73 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py @@ -0,0 +1,124 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gumbel bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "Gumbel", +] + + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py new file mode 100644 index 0000000000000000000000000000000000000000..132dc570f94719b6c71fb269866c943774481b7e --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -0,0 +1,33 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MaskedAutoregressiveFlow bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "MaskedAutoregressiveFlow", + "masked_dense", + "masked_autoregressive_default_template", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..ae142883931274b594dbbafbe86bd71e75c621bc --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py @@ -0,0 +1,473 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MaskedAutoregressiveFlow bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops import variable_scope as variable_scope_lib +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "MaskedAutoregressiveFlow", + "masked_autoregressive_default_template", + "masked_dense", +] + + +class MaskedAutoregressiveFlow(bijector_lib.Bijector): + """Affine MaskedAutoregressiveFlow bijector for vector-valued events. + + The affine autoregressive flow [1] provides a relatively simple framework for + user-specified (deep) architectures to learn a distribution over vector-valued + events. Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + In the `tf.distributions` framework, a "normalizing flow" is implemented as a + `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + is implemented using a `tf.while_loop` and a deep neural network (DNN) with + masked weights such that the autoregressive property is automatically met in + the `inverse`. + + A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the + (expensive) forward-mode calculation to draw samples and the (cheap) + reverse-mode calculation to compute log-probabilities. Conversely, a + `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses + the (expensive) forward-mode calculation to compute log-probabilities and the + (cheap) reverse-mode calculation to compute samples. See "Example Use" + [below] for more details. + + Given a `shift_and_log_scale_fn`, the forward and inverse transformations are + (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` + must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka + "alpha" [2]) such that each are broadcastable with the arguments to `forward` + and `inverse`, i.e., such that the calculations in `forward`, `inverse` + [below] are possible. + + For convenience, `masked_autoregressive_default_template` is offered as a + possible `shift_and_log_scale_fn` function. It implements the MADE + architecture [2]. MADE is a feed-forward network that computes a `shift` and + `log(scale)` using `masked_dense` layers in a deep neural network. Weights are + masked to ensure the autoregressive property. It is possible that this + architecture is suboptimal for your task. To build alternative networks, + either change the arguments to `masked_autoregressive_default_template`, use + the `masked_dense` function to roll-out your own, or use some other + architecture, e.g., using `tf.layers`. + + Warning: no attempt is made to validate that the `shift_and_log_scale_fn` + enforces the "autoregressive property". + + Assuming `shift_and_log_scale_fn` has valid shape and autoregressive + semantics, the forward transformation is, + + ```python + def forward(x): + y = zeros_like(x) + event_size = x.shape[-1] + for _ in range(event_size): + shift, log_scale = shift_and_log_scale_fn(y) + y = x * math_ops.exp(log_scale) + shift + return y + ``` + + and the inverse transformation is, + + ```python + def inverse(y): + shift, log_scale = shift_and_log_scale_fn(y) + return (y - shift) / math_ops.exp(log_scale) + ``` + + Notice that the `inverse` does not need a for-loop. This is because in the + forward pass each calculation of `shift` and `log_scale` is based on the `y` + calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is + equivalent to the scaling used in `forward` after `event_size` passes, i.e., + the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this + also proves the transform is bijective.) + + #### Example Use + + ```python + ds = tf.contrib.distributions + bs = tf.contrib.distributions.bijectors + + dims = 5 + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + maf = ds.TransformedDistribution( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=bs.masked_autoregressive_default_template( + hidden_layers=[512, 512])), + event_shape=[dims]) + + x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. + maf.log_prob(x) # Almost free; uses Bijector caching. + maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. + + # [1] also describes an "Inverse Autoregressive Flow", e.g., + iaf = ds.TransformedDistribution( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.Invert(bs.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=bs.masked_autoregressive_default_template( + hidden_layers=[512, 512]))), + event_shape=[dims]) + + x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. + iaf.log_prob(x) # Almost free; uses Bijector caching. + iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. + + # In many (if not most) cases the default `shift_and_log_scale_fn` will be a + # poor choice. Here's an example of using a "shift only" version and with a + # different number/depth of hidden layers. + shift_only = True + maf_no_scale_hidden2 = ds.TransformedDistribution( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.MaskedAutoregressiveFlow( + bs.masked_autoregressive_default_template( + hidden_layers=[32], + shift_only=shift_only), + is_constant_jacobian=shift_only), + event_shape=[dims]) + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + """ + + def __init__(self, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + name=None): + """Creates the MaskedAutoregressiveFlow bijector. + + Args: + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + name = name or "masked_autoregressive_flow" + self._shift_and_log_scale_fn = shift_and_log_scale_fn + super(MaskedAutoregressiveFlow, self).__init__( + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _forward(self, x): + event_size = array_ops.shape(x)[-1] + def _loop_body(index, y0): + """While-loop body for autoregression calculation.""" + # Set caching device to avoid re-getting the tf.Variable for every while + # loop iteration. + with variable_scope_lib.variable_scope( + variable_scope_lib.get_variable_scope()) as vs: + if vs.caching_device is None: + vs.set_caching_device(lambda op: op.device) + shift, log_scale = self._shift_and_log_scale_fn(y0) + y = x + if log_scale is not None: + y *= math_ops.exp(log_scale) + if shift is not None: + y += shift + return index + 1, y + _, y = control_flow_ops.while_loop( + cond=lambda index, _: index < event_size, + body=_loop_body, + loop_vars=[0, array_ops.zeros_like(x, name="y0")]) + return y + + def _inverse(self, y): + shift, log_scale = self._shift_and_log_scale_fn(y) + x = y + if shift is not None: + x -= shift + if log_scale is not None: + x *= math_ops.exp(-log_scale) + return x + + def _inverse_log_det_jacobian(self, y): + _, log_scale = self._shift_and_log_scale_fn(y) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + +MASK_INCLUSIVE = "inclusive" +MASK_EXCLUSIVE = "exclusive" + + +def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): + """Generate the slices for building an autoregressive mask.""" + # TODO(b/67594795): Better support of dynamic shape. + slices = [] + col = 0 + d_in = n_in // num_blocks + d_out = n_out // num_blocks + row = d_out if mask_type == MASK_EXCLUSIVE else 0 + for _ in range(num_blocks): + row_slice = slice(row, None) + col_slice = slice(col, col + d_in) + slices.append([row_slice, col_slice]) + col += d_in + row += d_out + return slices + + +def _gen_mask(num_blocks, + n_in, + n_out, + mask_type=MASK_EXCLUSIVE, + dtype=dtypes.float32): + """Generate the mask for building an autoregressive dense layer.""" + # TODO(b/67594795): Better support of dynamic shape. + mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) + slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) + for [row_slice, col_slice] in slices: + mask[row_slice, col_slice] = 1 + return mask + + +def masked_dense(inputs, + units, + num_blocks=None, + exclusive=False, + kernel_initializer=None, + reuse=None, + name=None, + *args, + **kwargs): + """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. + + See [1] for detailed explanation. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + inputs: Tensor input. + units: Python `int` scalar representing the dimensionality of the output + space. + num_blocks: Python `int` scalar representing the number of blocks for the + MADE masks. + exclusive: Python `bool` scalar representing whether to zero the diagonal of + the mask, used for the first layer of a MADE. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the + `tf.glorot_random_initializer`. + reuse: Python `bool` scalar representing whether to reuse the weights of a + previous layer by the same name. + name: Python `str` used to describe ops managed by this function. + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + Output tensor. + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + # TODO(b/67594795): Better support of dynamic shape. + input_depth = inputs.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + + mask = _gen_mask(num_blocks, input_depth, units, + MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T + + if kernel_initializer is None: + kernel_initializer = init_ops.glorot_normal_initializer() + + def masked_initializer(shape, dtype=None, partition_info=None): + return mask * kernel_initializer(shape, dtype, partition_info) + + with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): + layer = layers.Dense( + units, + kernel_initializer=masked_initializer, + kernel_constraint=lambda x: mask * x, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse, + *args, + **kwargs) + return layer.apply(inputs) + + +def masked_autoregressive_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + log_scale_min_clip=-5., + log_scale_max_clip=3., + log_scale_clip_gradient=False, + name=None, + *args, + **kwargs): + """Build the MADE Model [1]. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the input and returns the `loc` ("mu" [1]) and + `log_scale` ("alpha" [1]) from the MADE network. + + Warning: This function uses `masked_dense` to create randomly initialized + `tf.Variables`. It is presumed that these will be fit, just as you would any + other neural architecture which uses `tf.layers.dense`. + + #### About Hidden Layers: + + Each element of `hidden_layers` should be greater than the `input_depth` + (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the + neural network). This is necessary to ensure the autoregressivity property. + + #### About Clipping: + + This function also optionally clips the `log_scale` (but possibly not its + gradient). This is useful because if `log_scale` is too small/large it might + underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` + bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` + `bool` indicates whether the gradient should also be clipped. The default does + not clip the gradient; this is useful because it still provides gradient + information (for fitting) yet solves the numerical stability problem. I.e., + `log_scale_clip_gradient = False` means + `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual + `grad[clip(x)] exp(clip(x))`. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed. Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The minimum value to clip by. Default: -5. + log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The maximum value to clip by. Default: 3. + log_scale_clip_gradient: Python `bool` indicating that the gradient of + `tf.clip_by_value` should be preserved. Default: `False`. + name: A name for ops managed by this function. Default: + "masked_autoregressive_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "masked_autoregressive_default_template", + values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): + """MADE parameterized via `masked_autoregressive_default_template`.""" + # TODO(b/67594795): Better support of dynamic shape. + input_depth = x.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() + else array_ops.shape(x)) + for i, units in enumerate(hidden_layers): + x = masked_dense( + inputs=x, + units=units, + num_blocks=input_depth, + exclusive=True if i == 0 else False, + activation=activation, + *args, + **kwargs) + x = masked_dense( + inputs=x, + units=(1 if shift_only else 2) * input_depth, + num_blocks=input_depth, + activation=None, + *args, + **kwargs) + if shift_only: + x = array_ops.reshape(x, shape=input_shape) + return x, None + x = array_ops.reshape( + x, shape=array_ops.concat([input_shape, [2]], axis=0)) + shift, log_scale = array_ops.unstack(x, num=2, axis=-1) + which_clip = (math_ops.clip_by_value if log_scale_clip_gradient + else _clip_by_value_preserve_grad) + log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) + return shift, log_scale + return template_ops.make_template( + "masked_autoregressive_default_template", _fn) + + +def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): + """Clips input while leaving gradient unaltered.""" + with ops.name_scope(name, "clip_by_value_preserve_grad", + [x, clip_value_min, clip_value_max]): + clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) + return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py new file mode 100644 index 0000000000000000000000000000000000000000..a187ce22d686ee1203802ae2bfe64b0e1a3ea850 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Permute bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Permute"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d8f2f41b28a88208a19824377f93882b767f03 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Permutation bijectors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + bs = tf.contrib.distributions.bijectors + + reverse = bs.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + Warning: `tf.estimator` may repeatedly build the graph thus + `Permute(np.random.permutation(event_size)).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d}`. + """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..8997f7ab6929745275edb38712a5bbb0a9b25ddb --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reshape bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Reshape"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..93682639aa3be3b8f59a369dedb6ee773c468130 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py @@ -0,0 +1,297 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reshape bijectors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "Reshape", +] + + +class Reshape(bijector_lib.Bijector): + """Reshapes the `event_shape` of a `Tensor`. + + The semantics generally follow that of `tf.reshape()`, with + a few differences: + * The user must provide both the input and output shape, so that + the transformation can be inverted. + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). + * The `Reshape` bijector does not currently support + partially-specified shapes, i.e., those with a dimension + implicitly specified by `-1`. + + Example usage: + ```python + + bs = tf.contrib.distributions.bijectors + + reverse = bs.Reshape(event_shape_out=[1,2], + event_shape_in=[2,]) + + reverse.forward([1., 2.]) # shape [2,] + # ==> [[1., 2.]] # shape [1,2] + + reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] + + reverse.inverse([[1., 2.]]) # shape [1,2] + # ==> [1., 2.] # shape [2,] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + """ + + def __init__(self, event_shape_out, event_shape_in, + validate_args=False, name=None): + """Creates a `Reshape` bijector. + + Args: + event_shape_out: An `int`-like vector-shaped `Tensor` + representing the fully specified (no -1's) event shape of the + transformed output. + event_shape_in: An `int`-like vector-shaped `Tensor` + representing the fully specified (no -1's) event shape of the + input. + validate_args: Python `bool` indicating whether arguments should + be checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if either `event_shape_in` or `event_shape_out` has + non-vector shape (`rank > 1`), or non-integer `dtype`. + ValueError: if either `event_shape_in` or `event_shape_out` + contains non-positive entries, or if their sizes do not match + (`prod(event_shape_in)` != `prod(event_shape_out)`), or if + their dimensionality(s) cannot be statically inferred. + """ + with ops.name_scope(name, "reshape", + values=[event_shape_out, event_shape_in]): + + event_shape_out = ops.convert_to_tensor(event_shape_out, + name="event_shape_out", + preferred_dtype=dtypes.int32) + event_shape_in = ops.convert_to_tensor(event_shape_in, + name="event_shape_in", + preferred_dtype=dtypes.int32) + + # check that input shapes are positive integers + assertions = [] + assertions += self._maybe_check_valid_shape( + event_shape_out, "event_shape_out", + validate_args=validate_args) + assertions += self._maybe_check_valid_shape( + event_shape_in, "event_shape_in", validate_args=validate_args) + + # check that prod(event_shape_in) = prod(event_shape_out) + assertions += self._maybe_check_matching_sizes( + event_shape_in, event_shape_out, validate_args=validate_args) + + self._assertions = assertions + self._event_shape_in = event_shape_in + self._event_shape_out = event_shape_out + self._event_shape_in_static = tensor_util.constant_value_as_shape( + event_shape_in) + self._event_shape_out_static = tensor_util.constant_value_as_shape( + event_shape_out) + + super(Reshape, self).__init__(is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") + + def _maybe_check_valid_shape(self, shape_tensor, label, + validate_args=False): + """Check that a shape Tensor is int-type and positive.""" + + assertions = [] + + if not shape_tensor.dtype.is_integer: + raise TypeError("{} dtype ({}) should be `int`-like.".format( + label, shape_tensor.dtype.name)) + + shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) + if shape_rank is not None and shape_rank > 1: + raise ValueError("{} rank should be <= 1.".format(label)) + + s = tensor_util.constant_value(shape_tensor) + if s is not None: + if (s <= 0).any(): + raise ValueError("{} entries must be positive, but found {}".format( + label, s)) + elif validate_args: + assertions.append(check_ops.assert_positive( + shape_tensor, message="{} entries must be positive".format(label))) + + return assertions + + def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, + validate_args=False): + """Check that prod(event_shape_in)==prod(event_shape_out).""" + + def _get_size_from_shape(shape): + """Computes size from a shape `Tensor`, statically if possible.""" + s = tensor_util.constant_value(shape) + if s is not None: + return [np.int32(np.prod(s))]*2 + return None, math_ops.reduce_prod(shape, name="size") + + # Ensure `event_shape_in` is compatible with `event_shape_out`. + event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking + event_shape_in) + event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking + event_shape_out) + + assertions = [] + if event_size_in_ is not None and event_size_out_ is not None: + if event_size_in_ != event_size_out_: + raise ValueError( + "Input `event_size` ({}) does not match output `event_size` ({}).". + format(event_size_in, event_size_out_)) + elif validate_args: + assertions.append(check_ops.assert_equal( + event_size_in, event_size_out, + message="Input/output `event_size`s do not match.")) + + return assertions + + def _reshape_helper(self, x, event_shape_in, event_shape_out): + """Reshape only the event_shape of an input `Tensor`.""" + + def _get_rank_from_shape(shape): + """Computes rank from a shape `Tensor`, statically if possible.""" + # Uses fact that rank is "shape of shape". + ndims = shape.shape.with_rank_at_least(1)[0].value + if ndims is not None: + return ndims, ndims + return None, array_ops.shape(shape)[0] + + event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) + + assertions = [] + # Ensure x.event_shape is compatible with event_shape_in. + if x.shape.ndims is not None: + x_ndims_, x_ndims = [x.shape.ndims]*2 + else: + x_ndims_, x_ndims = None, array_ops.rank(x) + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + if not np.equal(x_event_shape_, event_shape_in_).all(): + raise ValueError( + "Input `event_shape` ({}) does not match `event_shape_in` ({}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + assertions.append(check_ops.assert_equal( + x_event_shape, event_shape_in, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + self._event_shape_in_static.assert_is_compatible_with(input_shape) + return self._event_shape_out_static + + def _inverse_event_shape(self, output_shape): + self._event_shape_out_static.assert_is_compatible_with(output_shape) + return self._event_shape_in_static + + def _forward_event_shape_tensor(self, input_shape): + input_assertions = self._maybe_check_valid_shape( + input_shape, "input event shape", validate_args=self.validate_args) + input_assertions += self._maybe_check_matching_sizes( + input_shape, self._event_shape_out, + validate_args=self.validate_args) + + return control_flow_ops.with_dependencies( + input_assertions + self._assertions, self._event_shape_out) + + def _inverse_event_shape_tensor(self, output_shape): + + output_assertions = self._maybe_check_valid_shape( + output_shape, "output event shape", validate_args=self.validate_args) + output_assertions += self._maybe_check_matching_sizes( + output_shape, self._event_shape_in, validate_args=self.validate_args) + + return control_flow_ops.with_dependencies( + output_assertions + self._assertions, self._event_shape_in) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py index dac3d812eef28b6aed291db051726d0594f7316a..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py @@ -89,18 +89,18 @@ class SinhArcsinh(bijector.Bijector): """ def __init__(self, - skewness=0., - tailweight=1., + skewness=None, + tailweight=None, event_ndims=0, validate_args=False, - name="sinh_arcsinh"): + name="SinhArcsinh"): """Instantiates the `SinhArcsinh` bijector. Args: - skewness: Skewness parameter. Float-type `Tensor`. + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` - and broadcastable `shape`. + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. event_ndims: Python scalar indicating the number of dimensions associated with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be @@ -111,8 +111,12 @@ class SinhArcsinh(bijector.Bijector): self._name = name self._validate_args = validate_args with self._name_scope("init", values=[skewness, tailweight]): - self._skewness = ops.convert_to_tensor(skewness, name="skewness") - self._tailweight = ops.convert_to_tensor(tailweight, name="tailweight") + tailweight = 1. if tailweight is None else tailweight + skewness = 0. if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) if validate_args: self._tailweight = control_flow_ops.with_dependencies([ diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index f1b7bf468e92913e6d1d5dd965de9c3dc220f9ed..599c855cda434d9249187d5d154d50a8a8c49a6c 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -198,3 +198,19 @@ class ConditionalTransformedDistribution( distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) return self.distribution.survival_function(x, **distribution_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) + def _quantile(self, value, bijector_kwargs=None, distribution_kwargs=None): + if self._is_maybe_event_override: + raise NotImplementedError("quantile is not implemented when overriding " + "event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("quantile is not implemented when " + "bijector is not injective.") + bijector_kwargs = bijector_kwargs or {} + distribution_kwargs = distribution_kwargs or {} + # x_q is the "qth quantile" of X iff q = P[X <= x_q]. Now, since X = + # g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)], + # implies the qth quantile of Y is g(x_q). + inv_cdf = self.distribution.quantile(value, **distribution_kwargs) + return self.bijector.forward(inv_cdf, **bijector_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 3ed5592bf955e7c94c5b5b4555955ae62a63c35d..869b5698e57d199755ce1686a74a1eafe3b73e7d 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -160,7 +160,7 @@ def make_tril_scale( scale_tril = array_ops.matrix_set_diag(scale_tril, tril_diag) - return linalg.LinearOperatorTriL( + return linalg.LinearOperatorLowerTriangular( tril=_maybe_attach_assertion(scale_tril), is_non_singular=True, is_self_adjoint=False, diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 393c008242417f8a2bf44eed2d9b2e81800d34c7..6a74ca9a0ae1ad30081d21cc15a65be052a99e2a 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -45,24 +45,24 @@ class Independent(distribution_lib.Distribution): `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes. - Similarly, the `Independent` distribution specifies a distribution over - `[B, E]`-shaped events. It operates by reinterpreting the rightmost batch dims - as part of the event dimensions. The `reduce_batch_ndims` parameter controls - the number of batch dims which are absorbed as event dims; - `reduce_batch_ndims < len(batch_shape)`. For example, the `log_prob` function - entails a `reduce_sum` over the rightmost `reduce_batch_ndims` after calling - the base distribution's `log_prob`. In other words, since the batch - dimension(s) index independent distributions, the resultant multivariate will - have independent components. + Similarly, the `Independent` distribution specifies a distribution over `[B, + E]`-shaped events. It operates by reinterpreting the rightmost batch dims as + part of the event dimensions. The `reinterpreted_batch_ndims` parameter + controls the number of batch dims which are absorbed as event dims; + `reinterpreted_batch_ndims < len(batch_shape)`. For example, the `log_prob` + function entails a `reduce_sum` over the rightmost `reinterpreted_batch_ndims` + after calling the base distribution's `log_prob`. In other words, since the + batch dimension(s) index independent distributions, the resultant multivariate + will have independent components. #### Mathematical Details The probability function is, ```none - prob(x; reduce_batch_ndims) = tf.reduce_prod( + prob(x; reinterpreted_batch_ndims) = tf.reduce_prod( dist.prob(x), - axis=-1-range(reduce_batch_ndims)) + axis=-1-range(reinterpreted_batch_ndims)) ``` #### Examples @@ -73,7 +73,7 @@ class Independent(distribution_lib.Distribution): # Make independent distribution from a 2-batch Normal. ind = ds.Independent( distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. ind.batch_shape # ==> [] @@ -84,7 +84,7 @@ class Independent(distribution_lib.Distribution): distribution=ds.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. ind.batch_shape # ==> [] @@ -94,14 +94,17 @@ class Independent(distribution_lib.Distribution): """ def __init__( - self, distribution, reduce_batch_ndims=1, validate_args=False, name=None): + self, distribution, reinterpreted_batch_ndims=None, + validate_args=False, name=None): """Construct a `Independent` distribution. Args: distribution: The base distribution instance to transform. Typically an instance of `Distribution`. - reduce_batch_ndims: Scalar, integer number of rightmost batch dims which - will be regard as event dims. + reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims + which will be regarded as event dims. When `None` all but the first + batch axis (batch axis 0) will be transferred to event dimensions + (analogous to `tf.layers.flatten`). validate_args: Python `bool`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. @@ -109,19 +112,25 @@ class Independent(distribution_lib.Distribution): Default value: `Independent + distribution.name`. Raises: - ValueError: if `reduce_batch_ndims` exceeds `distribution.batch_ndims` + ValueError: if `reinterpreted_batch_ndims` exceeds + `distribution.batch_ndims` """ parameters = locals() name = name or "Independent" + distribution.name self._distribution = distribution with ops.name_scope(name): - reduce_batch_ndims = ops.convert_to_tensor( - reduce_batch_ndims, dtype=dtypes.int32, name="reduce_batch_ndims") - self._reduce_batch_ndims = reduce_batch_ndims - self._static_reduce_batch_ndims = tensor_util.constant_value( - reduce_batch_ndims) - if self._static_reduce_batch_ndims is not None: - self._reduce_batch_ndims = self._static_reduce_batch_ndims + if reinterpreted_batch_ndims is None: + reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims( + distribution) + reinterpreted_batch_ndims = ops.convert_to_tensor( + reinterpreted_batch_ndims, + dtype=dtypes.int32, + name="reinterpreted_batch_ndims") + self._reinterpreted_batch_ndims = reinterpreted_batch_ndims + self._static_reinterpreted_batch_ndims = tensor_util.constant_value( + reinterpreted_batch_ndims) + if self._static_reinterpreted_batch_ndims is not None: + self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims super(Independent, self).__init__( dtype=self._distribution.dtype, reparameterization_type=self._distribution.reparameterization_type, @@ -129,19 +138,19 @@ class Independent(distribution_lib.Distribution): allow_nan_stats=self._distribution.allow_nan_stats, parameters=parameters, graph_parents=( - [reduce_batch_ndims] + + [reinterpreted_batch_ndims] + distribution._graph_parents), # pylint: disable=protected-access name=name) self._runtime_assertions = self._make_runtime_assertions( - distribution, reduce_batch_ndims, validate_args) + distribution, reinterpreted_batch_ndims, validate_args) @property def distribution(self): return self._distribution @property - def reduce_batch_ndims(self): - return self._reduce_batch_ndims + def reinterpreted_batch_ndims(self): + return self._reinterpreted_batch_ndims def _batch_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): @@ -149,13 +158,14 @@ class Independent(distribution_lib.Distribution): batch_ndims = (batch_shape.shape[0].value if batch_shape.shape.with_rank_at_least(1)[0].value else array_ops.shape(batch_shape)[0]) - return batch_shape[:batch_ndims - self.reduce_batch_ndims] + return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims] def _batch_shape(self): batch_shape = self.distribution.batch_shape - if self._static_reduce_batch_ndims is None or batch_shape.ndims is None: + if (self._static_reinterpreted_batch_ndims is None + or batch_shape.ndims is None): return tensor_shape.TensorShape(None) - d = batch_shape.ndims - self._static_reduce_batch_ndims + d = batch_shape.ndims - self._static_reinterpreted_batch_ndims return batch_shape[:d] def _event_shape_tensor(self): @@ -165,15 +175,16 @@ class Independent(distribution_lib.Distribution): if batch_shape.shape.with_rank_at_least(1)[0].value else array_ops.shape(batch_shape)[0]) return array_ops.concat([ - batch_shape[batch_ndims - self.reduce_batch_ndims:], + batch_shape[batch_ndims - self.reinterpreted_batch_ndims:], self.distribution.event_shape_tensor(), ], axis=0) def _event_shape(self): batch_shape = self.distribution.batch_shape - if self._static_reduce_batch_ndims is None or batch_shape.ndims is None: + if (self._static_reinterpreted_batch_ndims is None + or batch_shape.ndims is None): return tensor_shape.TensorShape(None) - d = batch_shape.ndims - self._static_reduce_batch_ndims + d = batch_shape.ndims - self._static_reinterpreted_batch_ndims return batch_shape[d:].concatenate(self.distribution.event_shape) def _sample_n(self, n, seed): @@ -205,15 +216,16 @@ class Independent(distribution_lib.Distribution): return self.distribution.mode() def _make_runtime_assertions( - self, distribution, reduce_batch_ndims, validate_args): + self, distribution, reinterpreted_batch_ndims, validate_args): assertions = [] - static_reduce_batch_ndims = tensor_util.constant_value(reduce_batch_ndims) + static_reinterpreted_batch_ndims = tensor_util.constant_value( + reinterpreted_batch_ndims) batch_ndims = distribution.batch_shape.ndims - if batch_ndims is not None and static_reduce_batch_ndims is not None: - if static_reduce_batch_ndims > batch_ndims: - raise ValueError("reduce_batch_ndims({}) cannot exceed " + if batch_ndims is not None and static_reinterpreted_batch_ndims is not None: + if static_reinterpreted_batch_ndims > batch_ndims: + raise ValueError("reinterpreted_batch_ndims({}) cannot exceed " "distribution.batch_ndims({})".format( - static_reduce_batch_ndims, batch_ndims)) + static_reinterpreted_batch_ndims, batch_ndims)) elif validate_args: batch_shape = distribution.batch_shape_tensor() batch_ndims = ( @@ -221,13 +233,24 @@ class Independent(distribution_lib.Distribution): if batch_shape.shape.with_rank_at_least(1)[0].value is not None else array_ops.shape(batch_shape)[0]) assertions.append(check_ops.assert_less_equal( - reduce_batch_ndims, batch_ndims, - message="reduce_batch_ndims cannot exceed distribution.batch_ndims")) + reinterpreted_batch_ndims, batch_ndims, + message=("reinterpreted_batch_ndims cannot exceed " + "distribution.batch_ndims"))) return assertions def _reduce_sum(self, stat): - if self._static_reduce_batch_ndims is None: - range_ = array_ops.range(self._reduce_batch_ndims) + if self._static_reinterpreted_batch_ndims is None: + range_ = math_ops.range(self._reinterpreted_batch_ndims) else: - range_ = np.arange(self._static_reduce_batch_ndims) + range_ = np.arange(self._static_reinterpreted_batch_ndims) return math_ops.reduce_sum(stat, axis=-1-range_) + + def _get_default_reinterpreted_batch_ndims(self, distribution): + """Computes the default value for reinterpreted_batch_ndim __init__ arg.""" + ndims = distribution.batch_shape.ndims + if ndims is None: + which_maximum = math_ops.maximum + ndims = array_ops.shape(distribution.batch_shape_tensor())[0] + else: + which_maximum = np.maximum + return which_maximum(0, ndims - 1) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index e92bcf8c1fda3e550616abada2879e355866c055..5558ef0f255db684b229d129666634e50c625887 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -260,6 +260,14 @@ class MixtureSameFamily(distribution.Distribution): probs * self.components_distribution.mean(), axis=-1 - self._event_ndims) # [B, E] + def _log_cdf(self, x): + x = self._pad_sample_dims(x) + log_cdf_x = self.components_distribution.log_cdf(x) # [S, B, k] + log_mix_prob = nn_ops.log_softmax( + self.mixture_distribution.logits, dim=-1) # [B, k] + return math_ops.reduce_logsumexp( + log_cdf_x + log_mix_prob, axis=-1) # [S, B] + def _variance(self): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index ee3e02e0203a3338b7e6a40b7e3ff30c0a0940f0..040bc230722194316b8a74627344e315a2578281 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -237,7 +237,7 @@ class MultivariateNormalDiagPlusLowRank( scale_perturb_diag, name="scale_perturb_diag") if has_low_rank: - scale = linalg.LinearOperatorUDVHUpdate( + scale = linalg.LinearOperatorLowRankUpdate( scale, u=scale_perturb_factor, diag_update=scale_perturb_diag, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 221eed547bacd59d3c0d065f386fe45970f9bae9..f9952b2069d6dfd2593e6bd71ede0badf44cdf98 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -174,8 +174,8 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = control_flow_ops.with_dependencies( [assert_symmetric], covariance_matrix) # No need to validate that covariance_matrix is non-singular. - # LinearOperatorTriL has an assert_non_singular method that is called - # by the Bijector. + # LinearOperatorLowerTriangular has an assert_non_singular method that + # is called by the Bijector. # However, cholesky() ignores the upper triangular part, so we do need # to separately assert symmetric. scale_tril = linalg_ops.cholesky(covariance_matrix) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 50c7ba418be5b66127a3fde9f02a39b8f52ff841..300bdd5f6064a1cc9c336689ac4fae04338edb30 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -18,16 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = [ @@ -92,7 +91,7 @@ class MultivariateNormalLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -106,7 +105,7 @@ class MultivariateNormalLinearOperator( mvn = ds.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorTriL(scale)) + scale=la.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -243,8 +242,8 @@ class MultivariateNormalLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense())) else: @@ -254,8 +253,8 @@ class MultivariateNormalLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return math_ops.sqrt(array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense()))) else: @@ -299,7 +298,10 @@ def _kl_brute_force(a, b, name=None): def squared_frobenius_norm(x): """Helper to make KL calculation slightly more readable.""" # http://mathworld.wolfram.com/FrobeniusNorm.html - return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1])) + # The gradient of KL[p,q] is not defined when p==q. The culprit is + # linalg_ops.norm, i.e., we cannot use the commented out code. + # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1])) + return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1]) # TODO(b/35041439): See also b/35040945. Remove this function once LinOp # supports something like: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 48c4dddc8133d408e1beb7a8aef2abd676895fe3..260dcc18f513d5440d3d39368539274c03faa72a 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -121,6 +121,14 @@ class MultivariateNormalTriL( [-10, 0, 9]] # shape: [2, 3] mvn.prob(x).eval() # shape: [2] + # Instantiate a "learnable" MVN. + dims = 4 + with tf.variable_scope("model"): + mvn = ds.MultivariateNormalTriL( + loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), + scale_tril=ds.fill_triangular( + tf.get_variable(shape=[dims * (dims + 1) / 2], + dtype=tf.float32, name="chol_Sigma"))) ``` """ @@ -188,9 +196,9 @@ class MultivariateNormalTriL( assert_proper_shapes=validate_args) else: # No need to validate that scale_tril is non-singular. - # LinearOperatorTriL has an assert_non_singular method that is called - # by the Bijector. - scale = linalg.LinearOperatorTriL( + # LinearOperatorLowerTriangular has an assert_non_singular + # method that is called by the Bijector. + scale = linalg.LinearOperatorLowerTriangular( scale_tril, is_non_singular=True, is_self_adjoint=False, diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index c8c396f6f80cf7f3228a75d279fff91ae15813ad..3a58df80da6c02b056f5e5a63bf41de5fc6d44a4 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -167,8 +167,8 @@ class NegativeBinomial(distribution.Distribution): def _log_unnormalized_prob(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - return (self.total_count * math_ops.log1p(-self.probs) - + x * math_ops.log(self.probs)) + return (self.total_count * math_ops.log_sigmoid(-self.logits) + + x * math_ops.log_sigmoid(self.logits)) def _log_normalization(self, x): if self.validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 65ee3a16d624822dd69f9dea1507b96703db12be..8a95038a3c8eccf8a75fea79d0a62f9883b4f13a 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -29,7 +30,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -55,8 +55,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ``` where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms - are from [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that + are from [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that the second line made the substitution: `z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above] and `dl = sqrt(2) scale lambda(z) dz` @@ -65,8 +67,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): Poisson rate parameter. Unfortunately, the non-approximate distribution lacks an analytical probability density function (pdf). Therefore the `PoissonLogNormalQuadratureCompound` class implements an approximation based - on [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). + on [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). + Note: although the `PoissonLogNormalQuadratureCompound` is approximately the Poisson-LogNormal compound distribution, it is itself a valid distribution. Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are @@ -76,9 +81,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal [compound distribution]( - https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can + https://en.wikipedia.org/wiki/Compound_probability_distribution). Using + variable-substitution and [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can redefine the distribution to be a parameter-less convex combination of `deg` different Poisson samples. @@ -93,7 +100,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): : d=0, ..., deg-1 } ``` - where, [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( + where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html) and `prob = w / sqrt(pi)`. @@ -106,14 +113,15 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): pln = ds.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) """ def __init__(self, loc, scale, - quadrature_polynomial_degree=8, + quadrature_grid_and_probs=None, validate_args=False, allow_nan_stats=True, name="PoissonLogNormalQuadratureCompound"): @@ -124,8 +132,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. - quadrature_polynomial_degree: Python `int`-like scalar. - Default value: 8. + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. When `None`, defaults to: + `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -153,18 +163,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format( loc.dtype.name, scale.dtype.name)) - self._degree = quadrature_polynomial_degree - - grid, prob = np.polynomial.hermite.hermgauss( - deg=quadrature_polynomial_degree) - - # It should be that `sum(prob) == sqrt(pi)`, but self-normalization is - # more numerically stable. - prob = prob.astype(dtype.as_numpy_dtype) - prob /= np.linalg.norm(prob, ord=1) + grid, probs = distribution_util.process_quadrature_grid_and_probs( + quadrature_grid_and_probs, dtype, validate_args) + self._quadrature_grid = grid + self._quadrature_probs = probs + self._quadrature_size = distribution_util.dimension_size(probs, axis=0) self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(prob), + logits=math_ops.log(self._quadrature_probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -210,9 +216,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): return self._scale @property - def quadrature_polynomial_degree(self): - """Polynomial largest exponent used for Gauss-Hermite quadrature.""" - return self._degree + def quadrature_grid(self): + """Quadrature grid points.""" + return self._quadrature_grid + + @property + def quadrature_probs(self): + """Quadrature normalized weights.""" + return self._quadrature_probs def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( @@ -242,10 +253,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): [batch_size])), seed=distribution_util.gen_new_seed( seed, "poisson_lognormal_quadrature_compound")) - # Stride `quadrature_polynomial_degree` for `batch_size` number of times. + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * self._degree, - delta=self._degree, + limit=batch_size * self._quadrature_size, + delta=self._quadrature_size, dtype=ids.dtype) ids += offset rate = array_ops.gather( diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 699cf45a73883a49d116fa70c81a4f9ecb36e598..b6becfa9fc93f189a1a7bf7b2a7af8dc1f2e9720 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -130,7 +130,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): temperature, logits=None, probs=None, - dtype=dtypes.float32, + dtype=None, validate_args=False, allow_nan_stats=True, name="ExpRelaxedOneHotCategorical"): @@ -150,7 +150,8 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): `N - 1` dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of `logits` or `probs` should be passed in. - dtype: The type of the event samples (default: float32). + dtype: The type of the event samples (default: inferred from + logits/probs). validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -163,14 +164,21 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): """ parameters = locals() with ops.name_scope(name, values=[logits, probs, temperature]): + + self._logits, self._probs = distribution_util.get_logits_and_probs( + name=name, logits=logits, probs=probs, validate_args=validate_args, + multidimensional=True) + + if dtype is None: + dtype = self._logits.dtype + if not validate_args: + temperature = math_ops.cast(temperature, dtype) + with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): self._temperature = array_ops.identity(temperature, name="temperature") self._temperature_2d = array_ops.reshape(temperature, [-1, 1], name="temperature_2d") - self._logits, self._probs = distribution_util.get_logits_and_probs( - name=name, logits=logits, probs=probs, validate_args=validate_args, - multidimensional=True) logits_shape_static = self._logits.get_shape().with_rank_at_least(1) if logits_shape_static.ndims is not None: @@ -230,7 +238,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): def _sample_n(self, n, seed=None): sample_shape = array_ops.concat([[n], array_ops.shape(self.logits)], 0) - logits = self.logits * array_ops.ones(sample_shape) + logits = self.logits * array_ops.ones(sample_shape, dtype=self.dtype) logits_2d = array_ops.reshape(logits, [-1, self.event_size]) # Uniform variates must be sampled from the open-interval `(0, 1)` rather # than `[0, 1)`. To do so, we use `np.finfo(self.dtype.as_numpy_dtype).tiny` @@ -368,7 +376,7 @@ class RelaxedOneHotCategorical( temperature, logits=None, probs=None, - dtype=dtypes.float32, + dtype=None, validate_args=False, allow_nan_stats=True, name="RelaxedOneHotCategorical"): @@ -388,7 +396,8 @@ class RelaxedOneHotCategorical( dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of `logits` or `probs` should be passed in. - dtype: The type of the event samples (default: float32). + dtype: The type of the event samples (default: inferred from + logits/probs). validate_args: Unused in this distribution. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index cdf81526da57a864491ff9e97b474e7722f5516e..b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -51,8 +51,9 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): `(loc, scale, skewness, tailweight)`, via the relation: ``` - Y := loc + scale * F(Z) * (2 / F(2)) + Y := loc + scale * F(Z) * (2 / F_0(2)) F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) ``` This distribution is similar to the location-scale transformation @@ -61,7 +62,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then `Y = L(Z)` exactly. * `loc` is used in both to shift the result by a constant factor. - * Our definition of `C` ensures that + * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0` `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`. Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond `loc + 2 * scale` are the same. @@ -84,12 +85,12 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): `|Z| >> (|skewness| * tailweight)**tailweight`, we have `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`. - To see the argument about `C` and quantiles, note that + To see the argument regarding multiplying `scale` by `2 / F_0(2)`, ``` - P[(Y - loc) / scale <= 2] = P[F(Z) <= 2 * scale / C] - = P[Z <= F^{-1}(2 * scale / C)] - = P[Z <= 2]. + P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2] + = P[F(Z) <= F_0(2)] + = P[Z <= 2] (if F = F_0). ``` """ @@ -101,7 +102,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): distribution=None, validate_args=False, allow_nan_stats=True, - name="MultivariateNormalLinearOperator"): + name="SinhArcsinh"): """Construct SinhArcsinh distribution on `(-inf, inf)`. Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape @@ -138,6 +139,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): dtype = loc.dtype scale = ops.convert_to_tensor(scale, name="scale", dtype=dtype) tailweight = 1. if tailweight is None else tailweight + has_default_skewness = skewness is None skewness = 0. if skewness is None else skewness tailweight = ops.convert_to_tensor( tailweight, name="tailweight", dtype=dtype) @@ -149,7 +151,8 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): # Recall, with Z a random variable, # Y := loc + C * F(Z), # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) - # C := 2 * scale / F(2) + # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + # C := 2 * scale / F_0(2) if distribution is None: distribution = normal.Normal( loc=array_ops.zeros([], dtype=dtype), @@ -164,9 +167,15 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): # Make the SAS bijector, 'F'. f = bijectors.SinhArcsinh( skewness=skewness, tailweight=tailweight, event_ndims=0) + if has_default_skewness: + f_noskew = f + else: + f_noskew = bijectors.SinhArcsinh( + skewness=skewness.dtype.as_numpy_dtype(0.), + tailweight=tailweight, event_ndims=0) - # Make the Affine bijector, Z --> loc + C * Z. - c = 2 * scale / f.forward(ops.convert_to_tensor(2, dtype=dtype)) + # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2)) + c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype)) affine = bijectors.Affine( shift=loc, scale_identity_multiplier=c, diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index da7d3907acb6ac1c6c01ff739aa19fcb95fbb53d..77f2a39273dc365a4ac202d846dd2bc364655c86 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import histogram_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables as variables_ops __all__ = [ @@ -37,7 +38,7 @@ class DiscreteScalarDistributionTestHelpers(object): """DiscreteScalarDistributionTestHelpers.""" def run_test_sample_consistent_log_prob( - self, sess, dist, + self, sess_run_fn, dist, num_samples=int(1e5), num_threshold=int(1e3), seed=42, rtol=1e-2, atol=0.): """Tests that sample/log_prob are consistent with each other. @@ -50,7 +51,9 @@ class DiscreteScalarDistributionTestHelpers(object): are consistent. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -86,7 +89,7 @@ class DiscreteScalarDistributionTestHelpers(object): probs = math_ops.exp(dist.log_prob(edges)) probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b] - [counts_, probs_] = sess.run([counts, probs]) + [counts_, probs_] = sess_run_fn([counts, probs]) valid = counts_ > num_threshold probs_ = probs_[valid] counts_ = counts_[valid] @@ -94,7 +97,7 @@ class DiscreteScalarDistributionTestHelpers(object): rtol=rtol, atol=atol) def run_test_sample_consistent_mean_variance( - self, sess, dist, + self, sess_run_fn, dist, num_samples=int(1e5), seed=24, rtol=1e-2, atol=0.): """Tests that sample/mean/variance are consistent with each other. @@ -103,7 +106,9 @@ class DiscreteScalarDistributionTestHelpers(object): to the same distribution. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -129,7 +134,7 @@ class DiscreteScalarDistributionTestHelpers(object): mean_, variance_, stddev_ - ] = sess.run([ + ] = sess_run_fn([ sample_mean, sample_variance, sample_stddev, @@ -186,7 +191,7 @@ class VectorDistributionTestHelpers(object): def run_test_sample_consistent_log_prob( self, - sess, + sess_run_fn, dist, num_samples=int(1e5), radius=1., @@ -239,7 +244,9 @@ class VectorDistributionTestHelpers(object): https://en.wikipedia.org/wiki/Importance_sampling. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The distribution must have non-zero probability of sampling every point @@ -279,33 +286,39 @@ class VectorDistributionTestHelpers(object): def monte_carlo_hypersphere_volume(dist, num_samples, radius, center): # https://en.wikipedia.org/wiki/Importance_sampling x = dist.sample(num_samples, seed=seed) + x = array_ops.identity(x) # Invalidate bijector cacheing. return math_ops.reduce_mean( math_ops.exp(-dist.log_prob(x)) * is_in_ball(x, radius, center), axis=0) - [ - batch_shape_, - actual_volume_, - sample_volume_, - ] = sess.run([ - dist.batch_shape_tensor(), - actual_hypersphere_volume( - dims=dist.event_shape_tensor()[0], - radius=radius), - monte_carlo_hypersphere_volume( - dist, - num_samples=num_samples, - radius=radius, - center=center), - ]) - + # Build graph. + with ops.name_scope( + "run_test_sample_consistent_log_prob", + values=[num_samples, radius, center] + dist._graph_parents): # pylint: disable=protected-access + batch_shape = dist.batch_shape_tensor() + actual_volume = actual_hypersphere_volume( + dims=dist.event_shape_tensor()[0], + radius=radius) + sample_volume = monte_carlo_hypersphere_volume( + dist, + num_samples=num_samples, + radius=radius, + center=center) + init_op = variables_ops.global_variables_initializer() + + # Execute graph. + sess_run_fn(init_op) + [batch_shape_, actual_volume_, sample_volume_] = sess_run_fn([ + batch_shape, actual_volume, sample_volume]) + + # Check results. self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_), sample_volume_, rtol=rtol, atol=atol) def run_test_sample_consistent_mean_covariance( self, - sess, + sess_run_fn, dist, num_samples=int(1e5), seed=24, @@ -319,7 +332,9 @@ class VectorDistributionTestHelpers(object): to the same distribution. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -353,7 +368,7 @@ class VectorDistributionTestHelpers(object): covariance_, variance_, stddev_ - ] = sess.run([ + ] = sess_run_fn([ sample_mean, sample_covariance, sample_variance, diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 6d297ea1f11398bb6abcb73aef2fce15bf7b429f..92043d6a08833888c36009261addca0d14949ea8 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -23,10 +23,6 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_full_matrix as linop_full_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_tril as linop_tril_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -37,6 +33,10 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib +from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib +from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib +from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib static_value = distribution_util.static_value @@ -73,8 +73,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): denotes matrix multiplication. However, the non-approximate distribution does not have an analytical probability density function (pdf). Therefore the `VectorDiffeomixture` class implements an approximation based on - [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). I.e., in + [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in Note: although the `VectorDiffeomixture` is approximately the `SoftmaxNormal-Distribution` compound distribution, it is itself a valid distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which @@ -109,8 +111,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior") [compound distribution]( https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can + Using variable-substitution and [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can redefine the distribution to be a parameter-less convex combination of `K` different affine combinations of a `d` iid samples from `distribution`. @@ -141,7 +145,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): and, ```none - grid, weight = np.polynomial.hermite.hermgauss(quadrature_polynomial_degree) + grid, weight = np.polynomial.hermite.hermgauss(quadrature_size) prob[k] = weight[k] / sqrt(pi) lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i]) ``` @@ -185,7 +189,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -219,7 +223,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): distribution, loc=None, scale=None, - quadrature_polynomial_degree=8, + quadrature_grid_and_probs=None, validate_args=False, allow_nan_stats=True, name="VectorDiffeomixture"): @@ -248,7 +252,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): `k`-th element represents the `scale` used for the `k`-th affine transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices - quadrature_polynomial_degree: Python `int`-like scalar. + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. When `None`, defaults to: + `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -262,7 +269,8 @@ class VectorDiffeomixture(distribution_lib.Distribution): Raises: ValueError: if `not scale or len(scale) < 2`. ValueError: if `len(loc) != len(scale)` - ValueError: if `quadrature_polynomial_degree < 1`. + ValueError: if `quadrature_grid_and_probs is not None` and + `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` ValueError: if `validate_args` and any not scale.is_positive_definite. TypeError: if any scale.dtype != scale[0].dtype. TypeError: if any loc.dtype != scale[0].dtype. @@ -307,12 +315,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): name="endpoint_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip(loc, scale))] - if quadrature_polynomial_degree < 1: - raise ValueError("quadrature_polynomial_degree={} " - "is not at least 1".format( - quadrature_polynomial_degree)) - self._degree = quadrature_polynomial_degree - # TODO(jvdillon): Remove once we support k-mixtures. # We make this assertion here because otherwise `grid` would need to be a # vector not a scalar. @@ -320,17 +322,17 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - grid, prob = np.polynomial.hermite.hermgauss( - deg=quadrature_polynomial_degree) - grid = grid.astype(dtype.as_numpy_dtype) - prob = prob.astype(dtype.as_numpy_dtype) - prob /= np.linalg.norm(prob, ord=1) + grid, probs = distribution_util.process_quadrature_grid_and_probs( + quadrature_grid_and_probs, dtype, validate_args) + self._quadrature_grid = grid + self._quadrature_probs = probs + self._quadrature_size = distribution_util.dimension_size(probs, axis=0) # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to # `math_ops.log(self.mixture_distribution.probs)`. self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(prob), + logits=math_ops.log(probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -357,10 +359,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( - interpolate_loc(quadrature_polynomial_degree, + interpolate_loc(self._quadrature_size, self._interpolate_weight, loc), - interpolate_scale(quadrature_polynomial_degree, + interpolate_scale(self._quadrature_size, self._interpolate_weight, scale)))] @@ -416,9 +418,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return self._interpolated_affine @property - def quadrature_polynomial_degree(self): - """Polynomial largest exponent used for Gauss-Hermite quadrature.""" - return self._degree + def quadrature_grid(self): + """Quadrature grid points.""" + return self._quadrature_grid + + @property + def quadrature_probs(self): + """Quadrature normalized weights.""" + return self._quadrature_probs def _batch_shape_tensor(self): return self._batch_shape_ @@ -454,10 +461,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): seed=distribution_util.gen_new_seed( seed, "vector_diffeomixture")) - # Stride `self._degree` for `batch_size` number of times. + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * self._degree, - delta=self._degree, + limit=batch_size * self._quadrature_size, + delta=self._quadrature_size, dtype=ids.dtype) weight = array_ops.gather( @@ -772,8 +779,8 @@ def linop_scale(w, op): is_non_singular=op.is_non_singular, is_self_adjoint=op.is_self_adjoint, is_positive_definite=op.is_positive_definite) - if isinstance(op, linop_tril_lib.LinearOperatorTriL): - return linop_tril_lib.LinearOperatorTriL( + if isinstance(op, linop_tril_lib.LinearOperatorLowerTriangular): + return linop_tril_lib.LinearOperatorLowerTriangular( tril=w[..., array_ops.newaxis, array_ops.newaxis] * op.to_dense(), is_non_singular=op.is_non_singular, is_self_adjoint=op.is_self_adjoint, diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index c88572e17fa43ac11778bdddc02484d284b6eb36..356d78b67a8107750f68f7f84d73d1231f5b2b03 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -90,7 +90,7 @@ class VectorExponentialDiag( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 7123165417ea010fa9da5263e429734d34df3dbd..b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import ops @@ -26,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = ["VectorExponentialLinearOperator"] @@ -108,7 +108,7 @@ class VectorExponentialLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. @@ -247,7 +247,7 @@ class VectorExponentialLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and self.scale.is_self_adjoint): return array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense())) @@ -258,7 +258,7 @@ class VectorExponentialLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and self.scale.is_self_adjoint): return math_ops.sqrt( array_ops.matrix_diag_part(self.scale.matmul(self.scale.to_dense()))) diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index fdee57695e4e598929396ee4c9fe9f8014ea0f8b..c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import ops @@ -28,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = [ @@ -110,7 +110,7 @@ class VectorLaplaceLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] @@ -126,7 +126,7 @@ class VectorLaplaceLinearOperator( # Divide scale by sqrt(2) so that the final covariance will be what we want. vla = ds.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorTriL(scale / tf.sqrt(2))) + scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -271,8 +271,8 @@ class VectorLaplaceLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return 2. * math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return array_ops.matrix_diag_part( 2. * self.scale.matmul(self.scale.to_dense())) else: @@ -282,8 +282,8 @@ class VectorLaplaceLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return np.sqrt(2) * math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return np.sqrt(2) * math_ops.sqrt(array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense()))) else: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 488724e80ce3c4e46dc48263dc17c1bd91b5887a..544a8710709a0afb56c6ae6f36d35de892e8e420 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""SinhArcsinh transformation of a distribution.""" +"""Multi-dimensional (Vector) SinhArcsinh transformation of a distribution.""" from __future__ import absolute_import from __future__ import division @@ -52,8 +52,9 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): matrix multiplication): ``` - Y := loc + scale @ F(Z) * (2 / F(2)) + Y := loc + scale @ F(Z) * (2 / F_0(2)) F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) ``` This distribution is similar to the location-scale transformation @@ -62,7 +63,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then `Y = L(Z)` exactly. * `loc` is used in both to shift the result by a constant factor. - * Our definition of `C` ensures that + * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0` `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`. Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond `loc + 2 * scale` are the same. @@ -85,12 +86,12 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): `|Z| >> (|skewness| * tailweight)**tailweight`, we have `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`. - To see the argument about `C` and quantiles, note that + To see the argument regarding multiplying `scale` by `2 / F_0(2)`, ``` - P[(Y - loc) / scale <= 2] = P[F(Z) <= 2 * scale / C] - = P[Z <= F^{-1}(2 * scale / C)] - = P[Z <= 2]. + P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2] + = P[F(Z) <= F_0(2)] + = P[Z <= 2] (if F = F_0). ``` """ @@ -171,12 +172,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): ]): loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc tailweight = 1. if tailweight is None else tailweight + has_default_skewness = skewness is None skewness = 0. if skewness is None else skewness # Recall, with Z a random variable, # Y := loc + C * F(Z), # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) - # C := 2 * scale / F(2) + # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + # C := 2 * scale / F_0(2) # Construct shapes and 'scale' out of the scale_* and loc kwargs. # scale_linop is only an intermediary to: @@ -213,9 +216,16 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): tailweight, dtype=dtype, name="tailweight") f = bijectors.SinhArcsinh( skewness=skewness, tailweight=tailweight, event_ndims=1) + if has_default_skewness: + f_noskew = f + else: + f_noskew = bijectors.SinhArcsinh( + skewness=skewness.dtype.as_numpy_dtype(0.), + tailweight=tailweight, event_ndims=0) # Make the Affine bijector, Z --> loc + C * Z. - c = 2 * scale_diag_part / f.forward(ops.convert_to_tensor(2, dtype=dtype)) + c = 2 * scale_diag_part / f_noskew.forward( + ops.convert_to_tensor(2, dtype=dtype)) affine = bijectors.Affine( shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 9d30ce67197ebdeefc69d9b9979fdad4797bb183..e4ac65012b9c7e3ed5ada3ed75020f3905740156 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -251,8 +251,8 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbM) where M is the complexity of the operator solving a # vector system. E.g., for LinearOperatorDiag, each matmul is O(k**2), so - # this complexity is O(nbk**2). For LinearOperatorTriL, each matmul is - # O(k^3) so this step has complexity O(nbk^3). + # this complexity is O(nbk**2). For LinearOperatorLowerTriangular, + # each matmul is O(k^3) so this step has complexity O(nbk^3). x = self.scale_operator.matmul(x) # Undo make batch-op ready. @@ -307,8 +307,8 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbM*k) where M is the complexity of the operator solving # a vector system. E.g., for LinearOperatorDiag, each solve is O(k), so - # this complexity is O(nbk**2). For LinearOperatorTriL, each solve is - # O(k**2) so this step has complexity O(nbk^3). + # this complexity is O(nbk**2). For LinearOperatorLowerTriangular, + # each solve is O(k**2) so this step has complexity O(nbk^3). scale_sqrt_inv_x_sqrt = self.scale_operator.solve( scale_sqrt_inv_x_sqrt) @@ -544,7 +544,7 @@ class WishartCholesky(_WishartLinearOperator): super(WishartCholesky, self).__init__( df=df, - scale_operator=linalg.LinearOperatorTriL( + scale_operator=linalg.LinearOperatorLowerTriangular( tril=scale, is_non_singular=True, is_positive_definite=True, @@ -655,7 +655,7 @@ class WishartFull(_WishartLinearOperator): ] if validate_args else [], chol) super(WishartFull, self).__init__( df=df, - scale_operator=linalg.LinearOperatorTriL( + scale_operator=linalg.LinearOperatorLowerTriangular( tril=chol, is_non_singular=True, is_positive_definite=True, diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index dd305a78dcf716b62a444e5f9dcc01c708c522be..cb7b5cf462408a2db220987676d2e51629b94ab0 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -9,11 +9,17 @@ py_library( name = "tfe", srcs = ["tfe.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":datasets", + ":evaluator", + ":metrics", + ":network", ":saver", ":summary_writer", "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:util", "//tensorflow/python/eager:backprop", @@ -31,6 +37,7 @@ cuda_py_test( additional_deps = [ ":tfe", "//tensorflow/python:array_ops", + "//tensorflow/python:metrics", "//tensorflow/python:math_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:platform_test", @@ -43,8 +50,10 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/eager:context", @@ -57,10 +66,11 @@ py_test( srcs_version = "PY2AND3", deps = [ ":datasets", - "//tensorflow/contrib/data", + "//tensorflow/python:dtypes", "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python/data", "//tensorflow/python/eager:test", - "//third_party/py/numpy", ], ) @@ -69,7 +79,11 @@ py_library( srcs = ["saver.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python/eager:context", ], ) @@ -81,7 +95,8 @@ cuda_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", + "//tensorflow/python/eager:graph_callable", + "//tensorflow/python/eager:test", "//tensorflow/python:variables", ], ) @@ -92,12 +107,14 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/summary:gen_summary_ops", - "//tensorflow/contrib/summary:summary_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:state_ops", "//tensorflow/python:summary_op_util", - "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", ], ) @@ -115,6 +132,108 @@ cuda_py_test( ], ) +py_library( + name = "metrics", + srcs = [ + "metrics.py", + "metrics_impl.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", + ], +) + +py_test( + name = "metrics_test", + srcs = ["metrics_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":metrics", + "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "evaluator", + srcs = [ + "evaluator.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":datasets", + ":metrics", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", + ], +) + +py_test( + name = "evaluator_test", + srcs = ["evaluator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":evaluator", + ":metrics", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "network", + srcs = ["network.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:util", + "@six_archive//:six", + ], +) + +py_test( + name = "network_test", + srcs = ["network_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":network", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python/eager:test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 9973f4eee227b6c039aab58cfa40efe629fb2312..f83c4704117f8d77812b37e116e35b97905b5f8e 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support for tf.contrib.data when eager execution is enabled.""" +"""Iteration over tf.data.Datasets when eager execution is enabled.""" from __future__ import absolute_import from __future__ import division @@ -23,6 +23,8 @@ import threading from tensorflow.python.data.util import nest from tensorflow.python.eager import context from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -39,20 +41,23 @@ def _iterator_shared_name(): class Iterator(object): - """An iterator producing tf.Tensor objects from a tf.contrib.data.Dataset.""" + """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" def __init__(self, dataset): """Creates a new iterator over the given dataset. For example: ```python - dataset = tf.contrib.data.Dataset.range(4) + dataset = tf.data.Dataset.range(4) for x in Iterator(dataset): print(x) ``` + Tensors produced will be placed on the device on which this iterator object + was created. + Args: - dataset: A `tf.contrib.data.Dataset` object. + dataset: A `tf.data.Dataset` object. Raises: RuntimeError: When invoked without eager execution enabled. @@ -60,22 +65,27 @@ class Iterator(object): if not context.in_eager_mode(): raise RuntimeError( - "{} objects only make sense when eager execution is enabled".format( - type(self))) - ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access - self._output_types = dataset.output_types - self._flat_output_types = nest.flatten(dataset.output_types) - self._flat_output_shapes = nest.flatten(dataset.output_shapes) - self._resource = gen_dataset_ops.iterator( - container="", - shared_name=_iterator_shared_name(), - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) + "{} objects can only be used when eager execution is enabled, use " + "tf.data.Dataset.make_iterator or " + "tf.data.Dataset.make_one_shot_iterator for graph construction". + format(type(self))) + with ops.device("/device:CPU:0"): + ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access + self._output_types = dataset.output_types + self._flat_output_types = nest.flatten(dataset.output_types) + self._flat_output_shapes = nest.flatten(dataset.output_shapes) + self._resource = gen_dataset_ops.iterator( + container="", + shared_name=_iterator_shared_name(), + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + gen_dataset_ops.make_iterator(ds_variant, self._resource) + self._device = context.context().device_name def __del__(self): if self._resource is not None: - resource_variable_ops.destroy_resource_op(self._resource) + with ops.device("/device:CPU:0"): + resource_variable_ops.destroy_resource_op(self._resource) self._resource = None def __iter__(self): @@ -87,10 +97,19 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" try: - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - return nest.pack_sequence_as(self._output_types, ret) + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + with ops.device("/device:CPU:0"): + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) except errors.OutOfRangeError: raise StopIteration + # Copies tensors from CPU to the current device if necessary. + # TODO(rohanj): This should be replaced by the mechanism to have the + # runtime's threads copy tensors to the destination device. + with ops.device(self._device): + ret = [array_ops.identity(x) for x in ret] + return nest.pack_sequence_as(self._output_types, ret) diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a2da6b28c6bdbfa0f6a4ed5d303aa4a36b81fc19..c924d81c9d85e638e4f35f260664c0ee7d03257e 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,10 +16,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data import Dataset from tensorflow.contrib.eager.python import datasets +from tensorflow.python.data import Dataset from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops class IteratorTest(test.TestCase): @@ -69,6 +72,23 @@ class IteratorTest(test.TestCase): got2 = [x.numpy() for x in datasets.Iterator(ds)] self.assertAllEqual(got1, got2) + def testPyFunc(self): + + def my_map(inp): + return [[x + 1 for x in inp]] + + ds = Dataset.range(4).map( + lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64)) + got = [x.numpy() for x in datasets.Iterator(ds)] + self.assertAllEqual([[1], [2], [3], [4]], got) + + def testTensorsPlacedOnDevice(self): + ds = Dataset.from_tensors([0., 1.]) + with ops.device(test.gpu_device_name()): + x = datasets.Iterator(ds).next() + x = math_ops.add(x, x) + self.assertAllEqual([0., 2.], x.numpy()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..633c747e5ec6e371e1232fd41bc094540952c7c2 --- /dev/null +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -0,0 +1,345 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class Evaluator holds Metrics for the duration of an evaluation run.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.eager.python import datasets +from tensorflow.contrib.eager.python import metrics +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops + + +class Evaluator(object): + """This holds and updates Metrics for the duration of a single eval run. + + Usage: + evaluator = my_model.evaluator() # or MyEvaluator(my_model) + for example_batch in ...: + evaluator(example_batch) + results = evaluator.all_metric_results(optional_summary_writer) + + Or, if you are getting your examples from a tf.data.Dataset, you can use + the evaluate_on_dataset() method. + + Implementers of Evaluators should + (a) Call `track_metric()` and/or `track_evaluator()` in __init__(). + (b) Override the `call()` method. It will be passed the output of the + model's `eval_data()` method, and should call its contained metrics + (treating them as callables) and any child Evaluators (using their + call() method to avoid calling eval_data() again). + + Args: + model: A `Model` object with an `eval_data()` method. + """ + + def __init__(self, model): + self._model = model + self._metrics = {} + self._evaluators = {} + if context.in_graph_mode(): + self.call = function.defun(self.call) + + # ---- API for users ---- + def __call__(self, *args, **kwargs): + """Update metrics with a minibatch of input examples. + + Args: + *args: + **kwargs: Arguments representing an input mini-batch of examples to + pass to self.model.eval_data(). + + Returns: + The op to execute or None if executing eagerly. + """ + return self.call(self._model.eval_data(*args, **kwargs)) + + def init_variables(self): + """Return an op for initializing all contained uninitialized variables. + + Only for graph execution. Should be called after variables are created + in the first execution of __call__(). + + Returns: + An op. + + Raises: + RuntimeError: if eager execution is enabled. + + @compatibility(eager) + Only for graph execution. + @end_compatibility + """ + if context.in_eager_mode(): + raise RuntimeError("Evaluator.init_variables() not needed when " + "eager execution is enabled.") + return control_flow_ops.group([m.init_variables() for _, m in self.metrics]) + + def all_metric_results(self): # TODO(josh11b): Add optional summary_writer. + """Returns dict mapping metric name -> value.""" + results = {} + for name, metric in six.iteritems(self._metrics): + results[name] = metric.result() + for prefix, evaluator in six.iteritems(self._evaluators): + for name, metric in six.iteritems(evaluator._metrics): # pylint: disable=protected-access + results[prefix + "/" + name] = metric.result() + return results + + def evaluate_on_dataset(self, dataset, *args, **kwargs): + """Convenience method for performing an eval on a Dataset. + + Args: + dataset: Dataset object with the input data to evaluate on. + *args: + **kwargs: Optional additional arguments to __call__(). + + Returns: + @compatibility(eager) + When eager execution is enabled, this returns the result of performing + an evaluation as a dictionary. With graph execution, this returns a tuple + (init_op, call_op, results_op) which may be executed using this code: + ```python + sess.run(init_op) + try: + while True: + sess.run(call_op) + except tf.errors.OutOfRangeError: + pass + return sess.run(results_op) # A dictionary + + # equivalently: + return evaluator.run_evaluation(init_op, call_op, results_op, sess=sess) + ``` + @end_compatibility + """ + # TODO(josh11b): Add optional summary_writer. + if context.in_graph_mode(): + call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), + *args, **kwargs) + init_op = self.init_variables() + results_op = self.all_metric_results() + return (init_op, call_op, results_op) + # Eager case + for example in datasets.Iterator(dataset): + self.__call__(example, *args, **kwargs) + return self.all_metric_results() + + @staticmethod + def run_evaluation(init_op, call_op, results_op, sess=None): + """Convenience method for running the ops returned by evaluate_on_dataset. + + Args: + init_op: An op that initializes/resets evaluation state. + call_op: An op that updates evaluation state on a mini-batch of examples. + Must generate an tf.errors.OutOfRangeError when done. + results_op: A dictionary of tensors that compute the final evaluation + results from the evaulation state. + sess: The Session to run the evaluation in. Defaults to the default + Session. + + Returns: + A dictionary of values, parallel to results_op. + + Raises: + RuntimeError: if eager execution is enabled. + + @compatibility(eager) + Only for graph execution. + @end_compatibility + """ + if context.in_eager_mode(): + raise RuntimeError("Evaluator.run_evaluation() not supported when " + "eager execution is enabled.") + sess = sess or ops.get_default_session() + sess.run(init_op) + try: + while True: + sess.run(call_op) + except errors_impl.OutOfRangeError: + pass + return sess.run(results_op) + + # ---- To be implemented by descendants --- + def call(self, eval_data): + """Update metrics using the output of self.model. + + Note: This function is executed as a graph function in graph mode. + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + + Args: + eval_data: The output of self.model.eval_data() on a mini-batch of + examples. + """ + raise NotImplementedError("Evaluators must define a call member function.") + + # ---- For use by descendants --- + @property + def model(self): + return self._model + + def track_metric(self, metric): + """Add a Metric to be tracked. + + Metrics can only be tracked by one `Evaluator`. Metrics must be + tracked or they will not appear in `all_metric_results()`. + + Args: + metric: A `Metric` object. + + Returns: + The `metric` passed into this function. + + Raises: + RuntimeError: If called before __init__. + TypeError: If `metric` is not of the correct type. + ValueError: If there is a name collision between Metrics or `metric` + has already been added to another `Evaluator`. + """ + if not hasattr(self, "_metrics"): + raise RuntimeError( + "Need to call Evaluator.__init__ before adding metrics") + if not isinstance(metric, metrics.Metric): + raise TypeError( + "Evaluator.track_metric() passed type %s, not a tfe.metrics.Metric" % + (type(metric),)) + if metric.name in self._metrics: + if metric is self._metrics[metric.name]: + return metric + raise ValueError( + "Attempt to add two Metrics with the name '%s' to the same Evaluator " + "'%s'" % (metric.name, self.name)) + # pylint: disable=protected-access + if hasattr(metric, "_added_to_an_evaluator"): + raise ValueError("Metric %s already added to Evaluator %s" % + (metric.name, metric._added_to_an_evaluator)) + metric._added_to_an_evaluator = self.__class__.__name__ + # pylint: enable=protected-access + self._metrics[metric.name] = metric + return metric + + def track_evaluator(self, prefix, evaluator): + """Add a contained `Evaluator`. + + This is for delegating to another `Evaluator`, e.g. for when you have a + model with multiple heads. Users should manually invoke the child + `Evaluator`'s `call` method from their `call` method. + + Args: + prefix: A string. Metrics from `evaluator` are exported with this + prefix and a '/'. + evaluator: An `Evaluator` object. + + Returns: + The value of `evaluator` passed into this function. + + Raises: + RuntimeError: If called before __init__. + TypeError: If `evaluator` is not of the correct type. + ValueError: If an `Evaluator` has already been added with that `prefix`. + """ + if not hasattr(self, "_evaluators"): + raise RuntimeError( + "Need to call Evaluator.__init__ before adding evaluators") + if not isinstance(evaluator, Evaluator): + raise TypeError( + "Evaluator.track_evaluator() passed type %s, not a tfe.Evaluator." % + (type(evaluator),)) + if prefix in self._evaluators: + if evaluator is self._evaluators[prefix]: + return evaluator + raise RuntimeError( + "Attempt to add two Evaluators with the same prefix '%s'." % prefix) + self._evaluators[prefix] = evaluator + return evaluator + + @property + def metric_variables(self): + v = [] + for metric in six.itervalues(self._metrics): + v += metric.variables + for evaluator in six.itervalues(self._evaluators): + v += evaluator.metric_variables + return v + + @property + def metrics(self): + """Returns a list of (prefix, metric) pairs.""" + m = [] + for metric in six.itervalues(self._metrics): + m.append(("", metric)) + for prefix, evaluator in six.iteritems(self._evaluators): + m += [(prefix + "/" + p, m) for p, m in evaluator.metrics] + return m + + +class SparseSoftmaxEvaluator(Evaluator): + """Evaluator for a sparse softmax model. + + Computes a standard set of metrics for single-label, multi-class + models. + + Args: + model: A `SparseSoftmaxModel` object or a `Model` whose `eval_data()` + method produces a `dict` containing values for the loss, true + label, predicted class, and optional weights. + loss_key: Optional key for looking up the value of the loss in the + `eval_data()` dict. Defaults to "loss". + label_key: Optional key for looking up the value of the label in the + `eval_data()` dict. Defaults to "label". + predicted_class_key: Optional key for looking up the value of the + predicted class in the `eval_data()` dict. Defaults to "predicted_class". + weights_key: Optional key for looking up the value of the weights + in the `eval_data()` dict. Defaults to "weights". Note that weights + are optional, and default to 1 if not present in `eval_data`. + """ + + def __init__(self, model, loss_key="loss", label_key="label", + predicted_class_key="predicted_class", weights_key="weights"): + super(SparseSoftmaxEvaluator, self).__init__(model) + # TODO(josh11b): Expand this to include everything from the standard + # SparseSoftmax Head. + self.avg_loss = self.track_metric(metrics.Mean("Avg Loss")) + self.accuracy = self.track_metric(metrics.Accuracy()) + self.loss_key = loss_key + self.label_key = label_key + self.predicted_class_key = predicted_class_key + self.weights_key = weights_key + + def call(self, eval_data): + """Update metrics for `eval_data` dict (described above).""" + weights = eval_data.get(self.weights_key, None) + if weights is None: + self.avg_loss(eval_data[self.loss_key]) + self.accuracy(eval_data[self.label_key], + eval_data[self.predicted_class_key]) + else: + self.avg_loss(eval_data[self.loss_key], weights=weights) + self.accuracy(eval_data[self.label_key], + eval_data[self.predicted_class_key], + weights=weights) diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4652a6908126fbf4eae34c0892a4d26a3ea791fc --- /dev/null +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -0,0 +1,160 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Evaluator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.eager.python import evaluator +from tensorflow.contrib.eager.python import metrics +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test + + +class IdentityModel(object): + + def eval_data(self, d): + return d + + +class PrefixLModel(object): + + def eval_data(self, d): + return {"l_" + key: d[key] for key in d} + + +class SimpleEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(SimpleEvaluator, self).__init__(model) + self.mean = self.track_metric(metrics.Mean("mean")) + + def call(self, eval_data): + self.mean(eval_data) + + +class DelegatingEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(DelegatingEvaluator, self).__init__(model) + self.sub = self.track_evaluator("inner", SimpleEvaluator(model)) + self.mean = self.track_metric(metrics.Mean("outer-mean")) + + def call(self, eval_data): + # Keys here come from PrefixLModel, which adds "l_". + self.mean(eval_data["l_outer"]) + self.sub.call(eval_data["l_inner"]) + + +# pylint: disable=not-callable +class EvaluatorTest(test.TestCase): + + def testSimple(self): + e = SimpleEvaluator(IdentityModel()) + e(3.0) + e([5.0, 7.0, 9.0]) + results = e.all_metric_results() + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"].numpy()) + + def testComposition(self): + e = DelegatingEvaluator(PrefixLModel()) + e({"inner": 2.0, "outer": 100.0}) + e({"inner": 4.0, "outer": 1000.0}) + results = e.all_metric_results() + self.assertEqual(set(["inner/mean", "outer-mean"]), set(results.keys())) + self.assertEqual(3.0, results["inner/mean"].numpy()) + self.assertEqual(550.0, results["outer-mean"].numpy()) + + def testMetricVariables(self): + e = DelegatingEvaluator(PrefixLModel()) + e({"inner": 2.0, "outer": 100.0}) + prefix_count = {} + for v in e.metric_variables: + p = v.name.split("/")[0] + prefix_count[p] = prefix_count.get(p, 0) + 1 + self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count) + + def testDatasetEager(self): + e = SimpleEvaluator(IdentityModel()) + ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) + results = e.evaluate_on_dataset(ds) + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"].numpy()) + + def testDatasetGraph(self): + with context.graph_mode(), self.test_session(): + e = SimpleEvaluator(IdentityModel()) + ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) + init_op, call_op, results_op = e.evaluate_on_dataset(ds) + results = e.run_evaluation(init_op, call_op, results_op) + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"]) + + def testModelProperty(self): + m = IdentityModel() + e = SimpleEvaluator(m) + self.assertIs(m, e.model) + + def testMetricsProperty(self): + e = DelegatingEvaluator(PrefixLModel()) + names = set([(p, m.name) for p, m in e.metrics]) + self.assertEqual(set([("", "outer-mean"), ("inner/", "mean")]), names) + + def testSharedMetric(self): + + class MetricArgEvaluator(evaluator.Evaluator): + + def __init__(self, model, m): + super(MetricArgEvaluator, self).__init__(model) + self.m = self.track_metric(m) + + metric = metrics.Mean("mean") + model = IdentityModel() + e = MetricArgEvaluator(model, metric) + with self.assertRaisesRegexp(ValueError, "already added"): + MetricArgEvaluator(model, metric) + del e + + def testMetricTrackedTwice(self): + + class MetricTwiceEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(MetricTwiceEvaluator, self).__init__(model) + self.m = self.track_metric(metrics.Mean("mean")) + self.track_metric(self.m) # okay to track same metric again + + MetricTwiceEvaluator(IdentityModel()) + + +class SparseSoftmaxEvaluatorTest(test.TestCase): + + def testSimple(self): + e = evaluator.SparseSoftmaxEvaluator(IdentityModel()) + e({e.loss_key: 1.0, e.label_key: 5, e.predicted_class_key: 5}) + e({e.loss_key: [0.0, 3.0, 4.0], + e.label_key: [1, 2, 3], + e.predicted_class_key: [1, 1, 3]}) + results = e.all_metric_results() + self.assertEqual(set(["Avg Loss", "Accuracy"]), set(results.keys())) + self.assertEqual(2.0, results["Avg Loss"].numpy()) + self.assertEqual(0.75, results["Accuracy"].numpy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/metrics.py b/tensorflow/contrib/eager/python/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3100427376ddd480b50d967cf53e7831aaefb2 --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics.py @@ -0,0 +1,26 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Metrics namespace.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint:disable=wildcard-import +from tensorflow.contrib.eager.python.metrics_impl import * +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['Accuracy', 'Mean', 'Metric'] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba653af4a2465a17a17ff4ff019e69476f6434e --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -0,0 +1,304 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Metrics classes for computing the output of an evaluation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope + + +_to_replace = re.compile("[^A-Za-z0-9.]") + + +class Metric(object): + """A metric holds state for aggregating statistics over an evaluation run. + + Example use with eager execution: + + ```python + m = SomeMetric(...) + for input in ...: + m(input) + print(m.result()) + ``` + + Example use with graph execution: + + ```python + m = SomeMetric(...) + m_placeholder = tf.placeholder(...) + m_update = m(m_placeholder) + # Variables defined in first call, so get the initialization op afterwards. + m_init = m.init_variables() # or tf.global_variables_initializer() + m_result = m.result() + with tf.Session() as sess: + sess.run(m_init) + for input in ...: + sess.run(m_update, feed_dict={m_placeholder: input}) + print(sess.run(m_result)) + ``` + + Descendants will implement: + * `build()`: All variables should be created in this method, by calling + `self.add_variable()` as in: `self.var = self.add_variable(...)` + build() will be called in the first invocation of `__call__()`, with + the same arguments passed `call()`. + * `call()`: Has all updates to variables, as in: + self.var.assign_add(...) + * `result()`: Computes and returns a final value for the metric + from the variables in `self`. + + Decendants may override `aggregate()`, but usually won't need to. It + adds in the state from a list of metrics of the same type as `self`. + (Default is to sum all the variables.) Note that users should not call + `aggregate()`, it is for use by TensorFlow infrastructure. + """ + + def __init__(self, name=None): + self._built = False + self._vars = [] + self._initial_values = {} + self._updates = [] + name = name or self.__class__.__name__ + # Replace things like spaces in name to create a valid scope name. + scope_name = _to_replace.sub("_", name) + # We create the variable scope now to get the unique name that will + # be used as a variable prefix when build() calls add_variable(). + with variable_scope.variable_scope( + scope_name, use_resource=True, reuse=False) as scope: + pos = scope.name.rfind(scope_name) + self._name = name + scope.name[pos + len(scope_name):] + self._scope = scope + if context.in_graph_mode(): + # We make self.call() into a graph callable here, so that we can + # return a single op that performs all of the variable updates. + self._construction_scope = ops.get_default_graph().as_default + self.call = function.defun(self.call) + else: + self._construction_scope = context.eager_mode + + # ---- API for users ---- + def __call__(self, *args, **kwargs): + """Returns op to execute to update this metric for these inputs. + + Returns None if eager execution is enabled. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, passed on to `call()`. + """ + if not self._built: + with variable_scope.variable_scope( + self._scope), self._construction_scope(): + self.build(*args, **kwargs) + self._built = True + return self.call(*args, **kwargs) + + @property + def name(self): + return self._name + + @property + def variables(self): + return self._vars + + def init_variables(self): + """Initializes this Metric's variables. + + Should be called after variables are created in the first execution + of `__call__()`. If using graph execution, the return value should be + `run()` in a session before running the op returned by `__call__()`. + (See example above.) + + Returns: + If using graph execution, this returns an op to perform the + initialization. Under eager execution, the variables are reset to their + initial values as a side effect and this function returns None. + """ + if context.in_graph_mode(): + return control_flow_ops.group([v.initializer for v in self._vars]) + for v in self._vars: + v.assign(self._initial_values[v]) + + # ---- To be implemented by descendants --- + def build(self, *args, **kwargs): + """Method to create variables. + + Called by `__call__()` before `call()` for the first time. + + Args: + *args: + **kwargs: The arguments to the first invocation of `__call__()`. + `build()` may use the shape and/or dtype of these arguments + when deciding how to create variables. + """ + raise NotImplementedError("Metrics must define a build() member function") + + def call(self, *args, **kwargs): + """Accumulates statistics for the metric. Users should use __call__ instead. + + Note: This function is executed as a graph function in graph mode. + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, as passed to + `__call__()`. + """ + raise NotImplementedError("Metrics must define a call() member function") + + def result(self): # TODO(josh11b): Add an optional summary_writer parameter. + """Computes and returns a final value for the metric.""" + raise NotImplementedError("Metrics must define a result() member function") + + # We can support two different strategies of for doing data-parallel + # distributed metric computations: + # * Put metric variables on the first device and rely on small + # bandwidth needed to do updates. (Doesn't require any particular + # code in Metric implementations.) + # * Ask each type of metric to define an aggregation method to run + # at the end of eval to merge across devices. Note: this is good + # for the use case where they want to record the metric's state + # for each example and then later decide which examples they want + # to aggregate over. (Recommended -- not too much harder and adds + # flexibility over previous option.) + # I'm going with the second strategy since we can define a default + # implementation of aggregate() that will work for most descendants. + def aggregate(self, metrics): + """Adds in the state from a list of metrics. + + Default implementation sums all the metric variables. + + Args: + metrics: A list of metrics with the same type as `self`. + + Raises: + ValueError: If metrics contains invalid data. + """ + for m in metrics: + if type(self) != type(m): # pylint: disable=unidiomatic-typecheck + raise TypeError("All metrics must be the same type, '%s' != '%s'." % + (type(self), type(m))) + # pylint: disable=protected-access + for i in range(len(self._vars)): + if any(m._vars[i].name != self._vars[i].name for m in metrics): + raise ValueError("All metrics must have variables in the same order.") + self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics])) + # pylint: enable=protected-access + + # ---- For use by descendants --- + def add_variable(self, name, shape=None, dtype=None, initializer=None): + """***Only for use by descendants of Metric***.""" + if self._built: + raise RuntimeError("Can't call add_variable() except in build().") + v = variable_scope.get_variable(name, shape, dtype, initializer, + trainable=False, use_resource=True) + self._vars.append(v) + if context.in_eager_mode(): + self._initial_values[v] = v.value() + return v + + +class Mean(Metric): + """Computes the (weighted) mean of the given values.""" + # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? + # Or defaults to type of the input if it is tf.float32, else tf.float64? + + def __init__(self, name=None, dtype=dtypes.float64): + super(Mean, self).__init__(name=name) + self.dtype = dtype + + def build(self, *args, **kwargs): + # build() does not use call's arguments, by using *args, **kwargs + # we make it easier to inherit from Mean(). + del args, kwargs + self.numer = self.add_variable(name="numer", shape=(), + dtype=self.dtype, + initializer=init_ops.zeros_initializer) + self.denom = self.add_variable(name="denom", shape=(), + dtype=self.dtype, + initializer=init_ops.zeros_initializer) + + def call(self, values, weights=None): + """Accumulate statistics for computing the mean. + + For example, if values is [1, 3, 5, 7] then the mean is 4. + If the weights were specified as [1, 1, 0, 0] then the mean would be 2. + + Args: + values: Tensor with the per-example value. + weights: Optional weighting of each example. Defaults to 1. + """ + if weights is None: + self.denom.assign_add( + math_ops.cast(array_ops.identity(array_ops.size(values)), self.dtype)) + values = math_ops.reduce_sum(values) + self.numer.assign_add(math_ops.cast(values, self.dtype)) + else: + weights = math_ops.cast(weights, self.dtype) + self.denom.assign_add(math_ops.reduce_sum(weights)) + values = math_ops.cast(values, self.dtype) * weights + self.numer.assign_add(math_ops.reduce_sum(values)) + + def result(self): + t = self.numer / self.denom + summary_ops.scalar(name=self.name, tensor=t) + return t + + +class Accuracy(Mean): + """Calculates how often `predictions` matches `labels`.""" + + def __init__(self, name=None, dtype=dtypes.float64): + super(Accuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + For example, if labels is [1, 2, 3, 4] and predictions is [0, 2, 3, 4] + then the accuracy is 3/4 or .75. If the weights were specified as + [1, 1, 0, 0] then the accuracy would be 1/2 or .5. + + `labels` and `predictions` should have the same shape and type. + + Args: + labels: Tensor with the true labels for each example. One example + per element of the Tensor. + predictions: Tensor with the predicted label for each example. + weights: Optional weighting of each example. Defaults to 1. + """ + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..336ce9d307cd9e1afae2417c252ae98375a86ad9 --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -0,0 +1,164 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Metrics.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.eager.python import metrics +from tensorflow.contrib.summary import summary_ops +from tensorflow.core.util import event_pb2 +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import gfile +from tensorflow.python.training import training_util + + +class MetricsTest(test.TestCase): + + def testMean(self): + m = metrics.Mean() + m([1, 10, 100]) + m(1000) + m([10000.0, 100000.0]) + self.assertEqual(111111.0/6, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testInitVariables(self): + m = metrics.Mean() + m([1, 10, 100, 1000]) + m([10000.0, 100000.0]) + self.assertEqual(111111.0/6, m.result().numpy()) + m.init_variables() + m(7) + self.assertEqual(7.0, m.result().numpy()) + + def testWriteSummaries(self): + m = metrics.Mean() + m([1, 10, 100]) + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name="t0").as_default(), summary_ops.always_record_summaries(): + m.result() # As a side-effect will write summaries. + + self.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + self.assertEqual(len(files), 1) + records = list( + tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + self.assertEqual(len(records), 2) + event = event_pb2.Event() + event.ParseFromString(records[1]) + self.assertEqual(event.summary.value[0].simple_value, 37.0) + + def testWeightedMean(self): + m = metrics.Mean() + m([1, 100, 100000], weights=[1, 0.2, 0.3]) + m([500000, 5000, 500]) # weights of 1 each + self.assertNear(535521/4.5, m.result().numpy(), 0.001) + + def testMeanDtype(self): + # Can override default dtype of float64. + m = metrics.Mean(dtype=dtypes.float32) + m([0, 2]) + self.assertEqual(1, m.result().numpy()) + self.assertEqual(dtypes.float32, m.dtype) + self.assertEqual(dtypes.float32, m.result().dtype) + + def testAccuracy(self): + m = metrics.Accuracy() + m([0, 1, 2, 3], [0, 0, 0, 0]) # 1 correct + m([4], [4]) # 1 correct + m([5], [0]) # 0 correct + m([6], [6]) # 1 correct + m([7], [2]) # 0 correct + self.assertEqual(3.0/8, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testWeightedAccuracy(self): + m = metrics.Accuracy() + # 1 correct, total weight of 2 + m([0, 1, 2, 3], [0, 0, 0, 0], weights=[1, 1, 0, 0]) + m([4], [4], weights=[0.5]) # 1 correct with a weight of 0.5 + m([5], [0], weights=[0.5]) # 0 correct, weight 0.5 + m([6], [6]) # 1 correct, weight 1 + m([7], [2]) # 0 correct, weight 1 + self.assertEqual(2.5/5, m.result().numpy()) + + def testAccuracyDtype(self): + # Can override default dtype of float64. + m = metrics.Accuracy(dtype=dtypes.float32) + m([0, 0], [0, 1]) + self.assertEqual(0.5, m.result().numpy()) + self.assertEqual(dtypes.float32, m.dtype) + self.assertEqual(dtypes.float32, m.result().dtype) + + def testTwoMeans(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean() + m1(0) + with self.assertRaises(ValueError): + m2 = metrics.Mean() + m2(2) + + def testNamesWithSpaces(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean("has space") + m1(0) + self.assertEqual(m1.name, "has space") + self.assertEqual(m1.numer.name, "has_space/numer:0") + + def testGraph(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + p = array_ops.placeholder(dtypes.float32) + accumulate = m(p) + init_op = m.init_variables() + init_op.run() + sess.run(accumulate, feed_dict={p: [1, 10, 100]}) + sess.run(accumulate, feed_dict={p: 1000}) + sess.run(accumulate, feed_dict={p: [10000, 100000]}) + self.assertAllEqual(m.result().eval(), 111111.0/6) + # Second init resets all the variables. + init_op.run() + sess.run(accumulate, feed_dict={p: 7}) + self.assertAllEqual(m.result().eval(), 7) + + def testTwoMeansGraph(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + with context.graph_mode(): + m1 = metrics.Mean() + m1(0) + with self.assertRaises(ValueError): + m2 = metrics.Mean() + m2(2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py new file mode 100644 index 0000000000000000000000000000000000000000..5b53a597f20a1cd0ba9be7f1d3a89e117cde66e8 --- /dev/null +++ b/tensorflow/contrib/eager/python/network.py @@ -0,0 +1,801 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A Network is a composition of Layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import weakref + +from tensorflow.python.eager import context +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.framework import ops +from tensorflow.python.layers import base +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + +# pylint: disable=protected-access +# Explanation for protected-access disable: Network has lots of same-class and +# parent-class references across different objects, and some to private +# functions in base.py which should be reused. + + +_DeferredRestoration = collections.namedtuple( + + "_DeferredRestoration", + [ + # The map_func to use (either user-specified or the default). + "map_func", + # Boolean, True if the user specified an explicit map_func, for error + # messages. + "map_func_is_user", + # A mapping from checkpoint names to initial values of not-yet-created + # variables which should be restored. These values come from parsing a + # checkpoint. + "checkpointed_variables_to_restore", + # A mapping from checkpoint name to variable objects of variables which + # have already been restored, for error checking. + "restored_variables", + # The session to restore with (if in graph mode). + "session", + # Names of the Network where the restore was requested, for error + # messages. + "network_name", + "network_scope_name" + ]) + + +def _default_naming_conflict_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The default checkpoint variable name mapping strategy for Network " + "'%s' resulted in a naming conflict. We attempted to strip off the " + "variable prefix for the Network ('%s'), but this resulted in two " + "variables named '%s' (originally '%s' and '%s'). This should only " + "happen when using variable sharing (i.e. the Network contains Networks " + "or Layers which were first added to another Network, and therefore " + "have that Network's variable prefix). One solution is to pass " + "`map_func=lambda n: n` to Network.save and Network.restore to use " + "fully qualified variable names in the checkpoint, although this will " + "require that the variable prefix of the Network being restored into " + "is also '%s'. You may alternatively write an arbitrary mapping.") + % ( + network_name, network_scope_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, network_scope_name + )) + + +def _restore_custom_map_func_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The map_func passed to Network.restore for the Network '%s' " + "resulted in two variables named '%s' (originally '%s' and '%s'). Since " + "this is also an error on Network.save, this Network was " + "probably not saved with this map_func. Note that map_func " + "always maps from full variable names to checkpoint names; " + "there is no need to specify an inverse mapping.\n\n" + "Try stripping less from the variable names, or renaming parts " + "of the Network. For reference, variables created by sub-Layers " + "of this Network are prefixed with '%s', but if they are " + "re-used after being added to another Network they will have " + "that Network's full variable prefix instead.") % ( + network_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, + network_scope_name)) + + +def _make_custom_getter_for_deferred_restorations(): + """Returns a custom getter which searches `deferred_restorations`. + + Returns: A tuple of (_custom_getter, deferred_restorations) + _custom_getter: The getter which should be added to variable_scopes where + variables will be created. + deferred_restorations: A list for _DeferredRestoration objects. Typically + empty when the getter is set, and expanded as deferred restorations are + requested. All new deferred restorations should be appended to the end of + the list, where they will have priority over older deferred restorations. + """ + deferred_restorations = [] + + def _custom_getter(getter, name, shape=None, dtype=None, + initializer=None, + *args, **kwargs): + """A custom getter which processes deferred restorations.""" + # Iterate over restorations, newest first (newer restorations will take + # precedence over older restorations, just like with immediate restorations + # into existing variables). + delayed_restoration = None + found_value = False + value_to_restore = None + for delayed_restoration in reversed( + deferred_restorations): + checkpoint_name = delayed_restoration.map_func(name) + if (checkpoint_name + in delayed_restoration.checkpointed_variables_to_restore): + found_value = True + value_to_restore = ( + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name]) + if found_value: + break + # value_to_restore may be False because this variable is not in any + # checkpoint we are restoring, or None because we have explicitly set it to + # None when it was previously fetched. In either case, we don't need to + # set an initializer. + if found_value and value_to_restore is not None: + initializer = value_to_restore + shape = None + variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, + *args, **kwargs) + if found_value and value_to_restore is not None: + # Mark as already restored from this checkpoint. + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name] = None + if context.in_graph_mode(): + delayed_restoration.session.run(variable.initializer) + if found_value: + # Error checking should run even if we've already restored a value. + if delayed_restoration.restored_variables.setdefault( + checkpoint_name, variable) is not variable: + # Naming conflict. We've tried to initialize two variables with the + # same value from the checkpoint. + if delayed_restoration.map_func_is_user: + raise ValueError( + _restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + else: + raise ValueError( + _default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + return variable + return _custom_getter, deferred_restorations + + +class Network(base.Layer): + """Represents the composition of a set of Layers. + + TODO(josh11b,ashankar): + - Should "trainable" be changeable on the Network object? + - Do we allow add_variable in Network? + - Detect layers used in __call__ that weren't registered with track_layer. + - Convert inputs to __call__ to tensors. + - Prevent variables from being created after the first __call__? + (Think about restoring from a checkpoint). + """ + + def __init__(self, name=None): + if isinstance(name, variable_scope.VariableScope): + raise ValueError("VariableScopes are not valid Network names.") + if name is not None and "/" in name: + raise ValueError( + "Forward slashes ('/') are not allowed in Network names.") + super(Network, self).__init__(name=name) + self._layers = [] + self._sub_layer_name_uids = collections.defaultdict(int) + # Initially None, but set to False for networks which are first built as + # top-level. + self._first_parent = None # A weak reference to our first parent. + self._non_network_sublayers = [] + self._owned_layers = {} + # The scope to use if we end up without a parent. + self._default_parent_variable_scope = variable_scope.get_variable_scope() + self._custom_getter, self._deferred_restorations = ( + _make_custom_getter_for_deferred_restorations()) + + def _init_set_name(self, name): + # Anonymous Networks (name=None) defer setting a final name until they are + # (1) added to another Network, or (2) built/called (where (2) is only used + # for a "top level" network). + # + # However, if we were provided an explicit name (name is not None), that + # will always be the final name of the Network; if it turns out not to be + # unique or if variable names can't be prefixed by it we will throw an + # error. + self._name = name + self._base_name = None + + def _finalize_name(self, parent_network): + if not self._name: + if not parent_network: + name_uid_map = base._get_default_graph_uid_map() + else: + name_uid_map = parent_network._sub_layer_name_uids + # Were were not passed a name explicitly (or it was blank), so this is an + # anonymous Network. We make up a unique name. + if parent_network: + avoid_names = parent_network._owned_layers + else: + avoid_names = None + self._name, self._base_name = self._make_unique_name( + name_uid_map=name_uid_map, avoid_names=avoid_names) + if self._first_parent is None or (self._first_parent # False = no parent + and self._first_parent() is None): + # Save a pointer to the parent Network so that we can later check that the + # scope name we get is correct. + if not parent_network: + self._first_parent = parent_network + else: + self._first_parent = weakref.ref(parent_network) + + def _set_scope(self, scope=None): + if self._scope is None: + if not self._first_parent: + first_parent = self._first_parent + else: + first_parent = self._first_parent() + if first_parent is None: + # If we were never added to another Network, or that Network has beed + # garbage collected before being called, then we're a top-level Network. + self._finalize_name( + # Use False to make sure the value sticks and we don't inherit a + # parent if we're added to a network later. + parent_network=False) + if scope is not None: + raise ValueError("Networks may not be created with explicit scopes.") + if first_parent: + first_parent._set_scope() + parent_scope = first_parent._scope + else: + parent_scope = self._default_parent_variable_scope + with variable_scope.variable_scope(parent_scope): + # Make sure variables with this prefix will be unique. + with variable_scope.variable_scope( + None, use_resource=True, default_name=self._name) as scope: + self._scope = scope + scope_name = scope.name + suffix_start = scope_name.rfind("/") + 1 + # rfind is -1 if there is no slash in the string, in which case the + # suffix starts at the beginning of the string (there is no prefix). + scope_suffix = scope_name[suffix_start:] + scope_prefix = scope_name[:suffix_start] + if scope_suffix != self._name: + raise ValueError( + ("A Network named '%s' already exists (or a variable_scope was " + "created with this name). Names must be unique.") % ( + self._name,)) + if (first_parent + and scope_prefix[:-1] != first_parent._scope.name): + raise ValueError( + ("Network variable names must match a nesting of sub-Network " + "names. Expected prefix '%s' from parent network, but got " + "'%s' when attempting to create a variable_scope for Network " + "'%s'. Likely an explicit variable_scope was inserted into " + "the nesting.") % ( + first_parent._scope.name, + scope_prefix[:-1], + self._name)) + elif not first_parent and scope_prefix: + # For the case when this Network is not nested inside any other + # Network, but is in a variable_scope. This is an error for now. + raise ValueError( + "Creating Networks inside named variable_scopes is currently " + "not supported (to ensure that variable names match the names " + "of Networks in which they were first created). To set " + "options, try `with tf.variable_scope(''):`. If this " + "limitation bothers you, please file a feature request.") + for non_network_sublayer in self._non_network_sublayers: + self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) + + def _set_scope_for_nonnetwork_sublayer(self, sublayer): + if sublayer._scope is None: + if sublayer._first_parent is None: + constituent_first_parent = None + else: + constituent_first_parent = sublayer._first_parent() + if constituent_first_parent: + constituent_first_parent._set_scope() + parent_scope = constituent_first_parent._scope + else: + self._finalize_name(False) + raise ValueError( + ("The parent of a Layer added to Network %s was garbage collected " + "before the Layer was built. If this limitation bothers you " + "please, file a feature request.") % (self.name,)) + with variable_scope.variable_scope(parent_scope): + # Horrid hack to make Layer variable names which are direct + # sub-layers of Networks conform to the Network variable naming + # conventions. + with variable_scope.variable_scope( + None, use_resource=True, + default_name=sublayer.name) as sub_scope: + sublayer._scope = sub_scope + + @base.Layer.name.getter + def name(self): + if self._name is None: + raise ValueError( + "The network does not yet have a final name, but a name was " + "requested for it. Networks get a name when they are added to " + "another Network via track_layer, or when they are first " + "called/built.") + return self._name + + def track_layer(self, layer): + """Track a Layer in this Network. + + `Network` requires that all `Layer`s used in `call()` be tracked so that the + `Network` can export a complete list of variables. + + Args: + layer: A `tf.layers.Layer` object. + + Returns: + The passed in `layer`. + + Raises: + RuntimeError: If __init__ has not been called. + TypeError: If `layer` is the wrong type. + ValueError: If a `Layer` with the same name has already been added. + """ + if not hasattr(self, "_layers"): + raise RuntimeError("Need to call Network.__init__ before adding layers") + if not isinstance(layer, base.Layer): + raise TypeError( + "Network.track_layer() passed type %s, not a tf.layers.Layer" % + (type(layer),)) + if isinstance(layer, Network): + layer._finalize_name(parent_network=self) + else: + # `layer` is a non-Network, so it hasn't been named to follow Network + # conventions for contained Layers (i.e. the same conventions as for + # sub-Networks). This renaming is necessary to isolate Network variable + # naming from Layers constructed outside the Network and never added to it + # (because Layers are named globally). + if not layer.built: + if not hasattr(layer, "_first_parent"): + dereferenced_layer_first_parent = None + else: + dereferenced_layer_first_parent = layer._first_parent() + if dereferenced_layer_first_parent is None: + if layer._name != layer._base_name: + # If name and base_name do not match, then this Layer used anonymous + # naming and we have to rename it. Otherwise there's an explicit + # name, and we should respect it (subject to error checking). + layer._name, layer._base_name = layer._make_unique_name( + name_uid_map=self._sub_layer_name_uids, + avoid_names=self._owned_layers) + layer._first_parent = weakref.ref(self) + self._non_network_sublayers.append(layer) + if (not layer.built + and layer._first_parent + and self is layer._first_parent()): + if layer.name in self._owned_layers: + if self._owned_layers[layer.name] is layer: + return layer + raise ValueError( + "Attempt to add two Layers with the name '%s' to the same Network." + % (layer.name)) + self._owned_layers[layer.name] = layer + self._layers.append(layer) + return layer + + def get_layer(self, name=None, index=None): + """Get a contained `tf.layers.Layer` either by name or index. + + Args: + name: String matching one of the names of a contained `Layer`. Note that + the names of `Layer`s added to `Network`s may not be unique when doing + layer sharing (i.e. adding a `Layer` to this `Network` which was already + added to another `Network`). The lowest index `Layer` with a matching + name will be returned. + index: Integer in [0, number of layers). Layers are assigned an index + by the order they are added. + + Returns: + A `tf.layers.Layer` object. + + Raises: + ValueError: If neither or both of 'index' or 'name' is specified, or the + lookup failed. + """ + if index is not None: + if name is not None: + raise ValueError("Exactly one of 'index' or 'name' must be provided") + if len(self._layers) <= index: + raise ValueError("Was asked to retrieve layer at index " + str(index) + + " but model only has " + str(len(self._layers)) + + " layers.") + else: + return self._layers[index] + else: + if not name: + raise ValueError("Provide either a layer name or layer index.") + for layer in self._layers: + if layer.name == name: + return layer + raise ValueError("No such layer: " + name) + + # The following methods are for implementing the Layer interface. + + @property + def weights(self): + # TODO(josh11b): Should this return a set or perform de-duplication of + # variables in the case of shared layers/variables that appear in + # multiple places in the Network? + weights = [] + for layer in self._layers: + weights += layer.weights + return weights + + @property + def trainable_weights(self): + weights = [] + for layer in self._layers: + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in self._layers: + weights += layer.non_trainable_weights + return weights + + @property + def trainable(self): + return True + + @trainable.setter + def trainable(self, value): + if not value: + # We believe it better to decide which layers & networks are trainable + # at the Trainer level than here. Otherwise you can run into trouble if a + # layer/network is shared between two models, but is trainable in one + # but not the other (like with adversarial networks). + raise AttributeError("cannot mark Network as not trainable") + + @property + def layers(self): + return self._layers + + def add_variable(self, name, shape, dtype=None, initializer=None, + regularizer=None, trainable=True, constraint=None): + raise RuntimeError( + "add_variable not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + def _strip_variable_prefix(self, original_variable_name): + """The default map_func for saving or restoring variables. + + Strips the variable prefix for the Network on which save/restore was called, + and leaves other variable names fully qualified in the checkpoint. + + Args: + original_variable_name: The _shared_name of the variable (no :0 + suffix) to map. + Returns: + The checkpoint name of the variable. + """ + scope_name_with_slash = self.scope_name + "/" + if original_variable_name.startswith(scope_name_with_slash): + return original_variable_name[len(scope_name_with_slash):] + else: + return original_variable_name + + def save(self, save_path, global_step=None, map_func=None): + """Save variables from the Network to a checkpoint. + + Args: + save_path: Either a checkpoint prefix or the name of a directory to save + the checkpoint in (in which case the checkpoint will be named based on + the Network name). + global_step: The global step to use when naming the checkpoint. If None + (default), we will first try to get the default global step. If that + fails because no default global step exists, then the checkpoint is + created without a global step suffix. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. + Returns: + The checkpoint prefix for the saved checkpoint, which may be passed to + `Network.restore`. + Raises: + ValueError: If the Network has not yet been called, or if map_func results + in a name collision. + """ + if not self.built: + raise ValueError( + "Attempt to save the Network before it was first called. This means " + "variables have not yet been created, so there is nothing to save.") + self._set_scope() # scope_name should be available to map_funcs + if global_step is None: + global_step = training_util.get_global_step() + if os.path.isdir(save_path): + # If we were passed a directory, default to naming based on the Network + # name. + save_path = os.path.join(save_path, self.name) + user_map_func = map_func + if map_func is None: + map_func = self._strip_variable_prefix + variable_map = {} + for variable in self.variables: + mapped_name = map_func(variable._shared_name) + if variable_map.setdefault(mapped_name, variable) is not variable: + if user_map_func is None: + # Instead of erroring out, we could just re-try and silently use the + # full variable names in the checkpoint. This could be odd for deeply + # nested sub-Networks (since the full prefix from the nesting would + # get added), so for now we'll let the user deal with this case. + raise ValueError(_default_naming_conflict_error_message( + mapped_name=mapped_name, + first_variable=variable_map[mapped_name], + second_variable=variable, + network_name=self.name, + network_scope_name=self.scope_name)) + else: + # The user passed their own problematic map_func. + raise ValueError( + ("The map_func passed to Network.save for the Network '%s' " + "resulted in two variables named '%s' ('%s' and '%s'). Try " + "stripping less from the variable names, or renaming parts of " + "the Network. For reference, variables created by sub-Layers of " + "this Network are prefixed with '%s', but if they are re-used " + "after being added to another Network, they will have that " + "Network's full variable prefix instead.") % ( + self.name, mapped_name, + variable_map[mapped_name]._shared_name, + variable._shared_name, + self.scope_name)) + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + return saver_lib.Saver(variable_map).save( + sess=sess, save_path=save_path, write_meta_graph=False, + global_step=global_step) + + def _restore_existing_variables(self, save_path, map_func, user_map_func): + """Use a standard Saver to restore existing variables from a checkpoint. + + Args: + save_path: The checkpoint prefix or directory to read from. + map_func: The function to use when mapping from variable names to + checkpoint names. + user_map_func: The original map_func passed by the user, for error + checking. + Returns: + A dictionary mapping from checkpoint names to variable objects which have + been restored (for bookkeeping to avoid deferred restorations on these + variables). + Raises: + ValueError: If there is a name collision. + """ + existing_variables_by_checkpoint_name = {} + for variable in self.variables: + checkpoint_name = map_func(variable._shared_name) + if existing_variables_by_checkpoint_name.setdefault( + checkpoint_name, variable) is not variable: + if user_map_func is None: + raise ValueError(_default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=self.name, + network_scope_name=self.scope_name)) + else: + raise ValueError(_restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=self.name, + network_scope_name=self.scope_name)) + if existing_variables_by_checkpoint_name: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( + sess=sess, save_path=save_path) + return existing_variables_by_checkpoint_name + + def _set_restore_on_create(self, save_path, map_func, user_map_func, + existing_variables_by_checkpoint_name): + """If necessary, request deferred restorations of variables.""" + checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) + checkpointed_variables_to_restore = {} + for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): + if checkpoint_name in existing_variables_by_checkpoint_name: + # This variable was already created and restored. + continue + # Save the variable for later restoration in a custom getter. + checkpointed_variables_to_restore[checkpoint_name] = ( + checkpoint_reader.get_tensor(checkpoint_name)) + # Only set a deferred restoration if there are checkpoint variables which + # have not been assigned to existing variables. Note that this loses out on + # some opportunity for error checking, but avoids creating + # _DeferredRestoration objects once a Network has been built (so that + # restoring in a loop does not take increasing amounts of memory). + if checkpointed_variables_to_restore: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + # We need a name for error messages. If we haven't been added to another + # Network yet, we're top-level. + self._finalize_name(False) + self._set_scope() + # Save a record of this restoration for use in the custom getter. + deferred_restoration = _DeferredRestoration( + map_func=map_func, + map_func_is_user=(user_map_func is not None), + checkpointed_variables_to_restore=checkpointed_variables_to_restore, + restored_variables={}, + session=sess, + network_name=self.name, + network_scope_name=self.scope_name) + self._deferred_restorations.append(deferred_restoration) + # Add the deferred registration to non-Network children, and request that + # Networks propagate the request to their children. + self._add_deferred_restoration(deferred_restoration) + + def _add_deferred_restoration(self, deferred_restoration): + """Add a deferred restoration to this Network and all children. + + Restorations which are requested later have higher priority, and the highest + priority matching restoration is applied to a variable when it is created. + + Args: + deferred_restoration: A _DeferredRestoration object. + """ + # Networks don't create variables at the moment, so this append isn't + # strictly necessary. We could get by with only adding deferred restorations + # to non-Network Layers. + self._set_scope() + # We use set_custom_getter because it avoids recursively calling up the + # variable_scope tree. We've done the tree traversal ourselves and have + # added the request to each Layer which needs it. + self._scope.set_custom_getter(self._custom_getter) + self._deferred_restorations.append(deferred_restoration) + for layer in self.layers: + if isinstance(layer, Network): + # For Networks, request that they propagate this deferred restoration + # to all of their children recursively. + layer._add_deferred_restoration(deferred_restoration) + else: + # For non-Network Layers, make sure they have a deferred restoration + # queue and a custom getter, then add our request to it. + if not hasattr(layer, "_custom_getter"): + assert not hasattr(layer, "_deferred_restorations") + layer._custom_getter, layer._deferred_restorations = ( + _make_custom_getter_for_deferred_restorations()) + self._set_scope_for_nonnetwork_sublayer(layer) + layer._scope.set_custom_getter(layer._custom_getter) + layer._deferred_restorations.append(deferred_restoration) + + def restore(self, save_path, map_func=None): + """Restore the Network from a checkpoint. + + If variables have already been created (typically when some or all of the + `Network` is built), they are assigned values from the checkpoint + immediately, overwriting any existing values (in graph mode the default + session is used for the assignments). + + If there are checkpoint entries which do not correspond to any existing + variables in the `Network`, these values are saved for deferred restoration; + their initial values will be the checkpointed values once they are + created. Requests for multiple deferred restorations behave the same way as + immediate restorations, in that later requests will take priority over + earlier requests relevant to the same variable. + + If this `Network` shares `Layer`s with another network, those `Layer`s will + also have their variables restored from the checkpoint. + + Args: + save_path: The return value of `Network.save`, or a directory to search + for a checkpoint. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. Note that this is the _same_ map_func as `Network.save`, not + an inverse mapping. + """ + self._finalize_name(parent_network=False) + self._set_scope() # scope_name should be available to map_funcs + if os.path.isdir(save_path): + # If we don't have a name yet, set no parent. + save_path = os.path.join(save_path, self.name) + user_map_func = map_func + if map_func is None: + map_func = self._strip_variable_prefix + # Step one is to restore any existing variables from the checkpoint. + existing_variables_by_checkpoint_name = self._restore_existing_variables( + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func) + # Step two is to set a custom getter which restores variables on creation, + # for those variables which have not been added to sub-Layers yet. + self._set_restore_on_create( + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func, + existing_variables_by_checkpoint_name=( + existing_variables_by_checkpoint_name)) + + # TODO(josh11b): Support other Layer methods needed for graph mode, such as for + # losses and updates + + +class Sequential(Network): + """Represents a linear sequence of Layers or functions. + + The output of each layer/function is provided as the input to the next. + The inputs passed to `__call__` are passed to the inputs of the first + Layer, and it returns the outputs of the last Layer. + + Args: + layers_funcs: An optional sequence where each element is either a + tf.layers.Layer object or a callable. + name: An optional string name to use for this Network. + """ + + def __init__(self, layers_funcs=None, name=None): + super(Sequential, self).__init__(name=name) + self._layers_funcs = [] + if layers_funcs: + for l in layers_funcs: + self.add(l) + + def add(self, layer_func): + if isinstance(layer_func, base.Layer): + args = estimator_util.fn_args(layer_func.call) + self.track_layer(layer_func) + elif callable(layer_func): + args = estimator_util.fn_args(layer_func) + else: + raise TypeError( + "Sequential.add() takes only tf.layers.Layer objects or callables; " + "not '%s' of type '%s'." % (layer_func, type(layer_func))) + self._layers_funcs.append((("training" in args), layer_func)) + + def call(self, inputs, training=None): + """Call each Layer in the order they were added.""" + # TODO(josh11b): Support "mode" and maybe other arguments + if training is None: + for _, l in self._layers_funcs: + inputs = l(inputs) + else: + for has_training_arg, l in self._layers_funcs: + if has_training_arg: + inputs = l(inputs, training) + else: + inputs = l(inputs) + return inputs diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c621f527c28306131bdba56d8427eaa787ba150b --- /dev/null +++ b/tensorflow/contrib/eager/python/network_test.py @@ -0,0 +1,1075 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc + +from tensorflow.contrib.eager.python import network +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import training_util + + +# pylint: disable=not-callable +class MyNetwork(network.Network): + + def __init__(self, name=None): + super(MyNetwork, self).__init__(name=name) + self.l1 = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.l1(x) + + +class NetworkTest(test.TestCase): + + def _save_modify_load_network_built(self, net, global_step=None): + checkpoint_directory = self.get_temp_dir() + checkpoint_path = net.save( + save_path=checkpoint_directory, global_step=global_step) + input_value = constant_op.constant([[42.0]]) + original_output = self.evaluate(net(input_value)) + for var in net.variables: + self.evaluate(var.assign(var + 1.)) + self.assertGreater( + self.evaluate(net(input_value)), + original_output) + # Either the returned explicit checkpoint path or the directory should work. + net.restore(save_path=checkpoint_directory) + self.assertAllEqual( + original_output, + self.evaluate(net(input_value))) + for var in net.variables: + self.evaluate(var.assign(var + 2.)) + net.restore(save_path=checkpoint_path) + self.assertAllEqual( + original_output, + self.evaluate(net(input_value))) + + @test_util.run_in_graph_and_eager_modes() + def testTrainableAttribute(self): + net = network.Network() + self.assertTrue(net.trainable) + with self.assertRaises(AttributeError): + net.trainable = False + self.assertTrue(net.trainable) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkCall(self): + net = MyNetwork(name="abcd") + net(constant_op.constant([[2.0]])) # Force variables to be created. + self.assertEqual(1, len(net.trainable_variables)) + self.evaluate(net.trainable_variables[0].assign([[17.0]])) + # TODO(josh11b): Support passing Python values to networks. + result = net(constant_op.constant([[2.0]])) + self.assertEqual(34.0, self.evaluate(result)) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkSaveRestoreAlreadyBuilt(self): + net = MyNetwork(name="abcd") + with self.assertRaisesRegexp( + ValueError, "Attempt to save the Network before it was first called"): + net.save(self.get_temp_dir()) + net(constant_op.constant([[2.0]])) + self.evaluate(net.trainable_variables[0].assign([[17.0]])) + self._save_modify_load_network_built(net, global_step=None) + self._save_modify_load_network_built(net, global_step=10) + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestoreDefaultGlobalStep(self): + net = MyNetwork(name="abcd") + net(constant_op.constant([[2.0]])) + self.evaluate(net.variables[0].assign([[3.]])) + default_global_step = training_util.get_or_create_global_step() + self.evaluate(default_global_step.assign(4242)) + save_path = net.save(self.get_temp_dir()) + self.assertIn("abcd-4242", save_path) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkSaveAndRestoreIntoUnbuilt(self): + save_dir = self.get_temp_dir() + net1 = MyNetwork() + test_input = constant_op.constant([[2.0]]) + net1(test_input) + self.evaluate(net1.trainable_variables[0].assign([[17.0]])) + save_path = net1.save(save_dir) + # With a pre-build restore we should have the same value. + net2 = MyNetwork() + net2.restore(save_path) + self.assertAllEqual(self.evaluate(net1(test_input)), + self.evaluate(net2(test_input))) + self.assertIsNot(net1.variables[0], net2.variables[0]) + self.assertAllEqual(self.evaluate(net1.variables[0]), + self.evaluate(net2.variables[0])) + + @test_util.run_in_graph_and_eager_modes() + def testLoadIntoUnbuiltSharedLayer(self): + + class Owner(network.Network): + + def __init__(self, name=None): + super(Owner, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, name="first_layer", use_bias=False)) + + def call(self, x): + return self.first(x) + + first_owner = Owner() + + class User(network.Network): + + def __init__(self, use_layer, name=None): + super(User, self).__init__(name=name) + self.first = self.track_layer(use_layer) + self.second = self.track_layer(core.Dense( + 1, name="second_layer", use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + class LikeUserButNotSharing(network.Network): + + def __init__(self, name=None): + super(LikeUserButNotSharing, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, name="first_layer", use_bias=False)) + self.second = self.track_layer(core.Dense( + 1, name="second_layer", use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator") + one = constant_op.constant([[1.0]]) + checkpoint_creator(one) + self.assertEqual(2, len(checkpoint_creator.variables)) + self.evaluate(checkpoint_creator.variables[0].assign([[5.]])) + self.evaluate(checkpoint_creator.variables[1].assign([[6.]])) + # Re-map the variable names so that with default restore mapping we'll + # attempt to restore into the unbuilt Layer. + name_mapping = { + "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel", + "checkpoint_creator/second_layer/kernel": "second_layer/kernel", + } + save_path = checkpoint_creator.save( + self.get_temp_dir(), + map_func=lambda full_name: name_mapping[full_name]) + load_into = User(use_layer=first_owner.first) + load_into.restore(save_path) + self.assertEqual(0, len(first_owner.variables)) + self.assertAllEqual(self.evaluate(checkpoint_creator(one)), + self.evaluate(load_into(one))) + self.assertEqual(1, len(first_owner.variables)) + self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0])) + self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1])) + first_owner(one) + self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0])) + + # Try again with a garbage collected parent. + first_owner = Owner() + load_into = User(use_layer=first_owner.first) + del first_owner + gc.collect() + def _restore_map_func(original_name): + if original_name.startswith("owner_1"): + return original_name.replace("owner_1", "owner_2") + else: + return "user_2/" + original_name + with self.assertRaisesRegexp(ValueError, "garbage collected"): + load_into.restore(save_path, map_func=_restore_map_func) + + @test_util.run_in_graph_and_eager_modes() + def testRestoreIntoSubNetwork(self): + + class Parent(network.Network): + + def __init__(self, name=None): + super(Parent, self).__init__(name=name) + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.first(self.second(x)) + + one = constant_op.constant([[3.]]) + whole_model_saver = Parent() + whole_model_saver(one) + self.evaluate(whole_model_saver.variables[0].assign([[15.]])) + self.evaluate(whole_model_saver.variables[1].assign([[16.]])) + whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir()) + + save_from = MyNetwork() + save_from(one) + self.evaluate(save_from.variables[0].assign([[5.]])) + checkpoint = save_from.save(self.get_temp_dir()) + save_into_parent = Parent() + save_into_parent.restore(whole_model_checkpoint) + save_into_parent.first.restore(checkpoint) + save_into_parent.first.restore(checkpoint) # deferred loading multiple + # times is fine + save_into_parent(one) # deferred loading + self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0])) + self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) + + # Try again with the opposite ordering, and we should get different results + # (deferred restoration should happen the same way non-deferred happens, + # with later restorations overwriting older ones). + save_into_parent = Parent() + save_into_parent.first.restore(checkpoint) # deferred loading multiple + # times is fine + save_into_parent.restore(whole_model_checkpoint) + save_into_parent(one) # deferred loading + # We've overwritten the sub-Network restore. + self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0])) + self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) + + self.evaluate(save_into_parent.variables[0].assign([[3.]])) + self.evaluate(save_into_parent.variables[1].assign([[4.]])) + save_into_parent.second.restore(checkpoint) + self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1])) + with self.assertRaisesRegexp(errors_impl.NotFoundError, + "not found in checkpoint"): + # The checkpoint is incompatible. + save_into_parent.restore(checkpoint) + + @test_util.run_in_graph_and_eager_modes() + def testCustomMapCollisionErrors(self): + + class Parent(network.Network): + + def __init__(self, name=None): + super(Parent, self).__init__(name=name) + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.first(self.second(x)) + + make_checkpoint = Parent() + one = constant_op.constant([[1.]]) + make_checkpoint(one) + self.evaluate(make_checkpoint.variables[0].assign([[2.]])) + self.evaluate(make_checkpoint.variables[1].assign([[3.]])) + with self.assertRaisesRegexp( + ValueError, + "The map_func passed to Network.save for the Network 'parent_1' " + "resulted in two variables named 'foo'"): + make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo") + checkpoint = make_checkpoint.first.save( + self.get_temp_dir(), map_func=lambda n: "foo") + loader = Parent() + loader.restore(checkpoint, map_func=lambda n: "foo") + with self.assertRaisesRegexp( + ValueError, + ("The map_func passed to Network.restore for the Network" + " 'parent_2' resulted in two variables named 'foo'")): + loader(one) + loader = Parent() + loader(one) + with self.assertRaisesRegexp( + ValueError, + ("The map_func passed to Network.restore for the Network" + " 'parent_3' resulted in two variables named 'foo'")): + loader.restore(checkpoint, map_func=lambda n: "foo") + + @test_util.run_in_graph_and_eager_modes() + def testDefaultMapCollisionErrors(self): + + one = constant_op.constant([[1.]]) + first = core.Dense(1, name="dense_1", use_bias=False) + first(one) + + class Parent(network.Network): + + def __init__(self, name=None): + super(Parent, self).__init__(name=name) + self.first = self.track_layer(first) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.first(self.second(x)) + + make_checkpoint = Parent() + one = constant_op.constant([[1.]]) + make_checkpoint(one) + self.evaluate(make_checkpoint.variables[0].assign([[2.]])) + self.evaluate(make_checkpoint.variables[1].assign([[3.]])) + with self.assertRaisesRegexp( + ValueError, + ("The default checkpoint variable name mapping strategy for Network " + "'parent_1' resulted in a naming conflict.")): + make_checkpoint.save(self.get_temp_dir()) + + class Compatible(network.Network): + + def __init__(self, name=None): + super(Compatible, self).__init__(name=name) + self.first = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.first(x) + + successful_checkpoint = Compatible() + successful_checkpoint(one) + self.evaluate(successful_checkpoint.variables[0].assign([[-1.]])) + checkpoint_path = successful_checkpoint.save(self.get_temp_dir()) + load_checkpoint = Parent() + load_checkpoint(one) + with self.assertRaisesRegexp( + ValueError, + ("The default checkpoint variable name mapping strategy for Network " + "'parent_2' resulted in a naming conflict.")): + load_checkpoint.restore(checkpoint_path) + + def testNoReferenceCyclesAfterCall(self): + + class ChildNetwork(network.Network): + + def __init__(self, name=None): + super(ChildNetwork, self).__init__(name=name) + + def call(self, x): + return x * 2. + + class ParentNetwork(network.Network): + + def __init__(self, name=None): + super(ParentNetwork, self).__init__(name=name) + self.l1 = self.track_layer(ChildNetwork()) + + def call(self, x): + return self.l1(x) + + one = constant_op.constant([[1.0]]) + gc.disable() + gc.collect() + previous_gc_debug_flags = gc.get_debug() + gc.set_debug(gc.DEBUG_SAVEALL) + preexisting = len(gc.garbage) + net = ParentNetwork() + net(one) + del net + gc.collect() + # There should be no additional garbage requiring collection. + self.assertEqual(preexisting, len(gc.garbage)) + gc.set_debug(previous_gc_debug_flags) + gc.enable() + + @test_util.run_in_graph_and_eager_modes() + def testAnonymousNoNameInitially(self): + net = MyNetwork() + with self.assertRaisesRegexp(ValueError, "does not yet have a final name"): + net.name # pylint: disable=pointless-statement + + @test_util.run_in_graph_and_eager_modes() + def testExplicitHasNameInitially(self): + net = MyNetwork(name="abcd") + self.assertEqual("abcd", net.name) + + @test_util.run_in_graph_and_eager_modes() + def testUsingResourceVariables(self): + net = MyNetwork() + net(constant_op.constant([[0.]])) + self.assertIsInstance(net.trainable_weights[0], + resource_variable_ops.ResourceVariable) + + @test_util.run_in_graph_and_eager_modes() + def testDuplicateNameError(self): + one = constant_op.constant([[1.]]) + net = MyNetwork(name="foo") + net(one) + with self.assertRaisesRegexp( + ValueError, "named 'foo' already exists"): + net1 = MyNetwork(name="foo") + net1(one) + + @test_util.run_in_graph_and_eager_modes() + def testWrappingInVariableScope(self): + with variable_scope.variable_scope("outside_scope"): + net = MyNetwork() + one = constant_op.constant([[1.]]) + with self.assertRaisesRegexp( + ValueError, + ("Creating Networks inside named variable_scopes is currently not " + "supported")): + net(one) + # Alternatively, we could re-name the Network to match the variable_scope: + # self.assertEqual("outside_scope/my_network_1", net.name) + # self.assertStartsWith( + # expected_start="outside_scope/my_network_1/dense/", + # actual=net.trainable_weights[0].name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerNamesRespected(self): + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer( + core.Dense(1, use_bias=False, name="explicit_name")) + + def call(self, x): + return self.first(x) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith(expected_start="parent_network_1/explicit_name/", + actual=net.trainable_weights[0].name) + self.assertEqual("explicit_name", net.first.name) + + @test_util.run_in_graph_and_eager_modes() + def testWrappingInAnonymousVariableScope(self): + # Named outside variable_scopes are not supported at the moment. However, + # blank-named top level variable scopes do not change variable names, and so + # can be used to set the properties of Network variables. + was_called = [False] + def _custom_getter(getter, *args, **kwargs): + was_called[0] = True + return getter(*args, **kwargs) + with variable_scope.variable_scope("", custom_getter=_custom_getter): + net = MyNetwork() + one = constant_op.constant([[1.]]) + net(one) + self.assertTrue(was_called[0]) + + @test_util.run_in_graph_and_eager_modes() + def testReasonableSlashError(self): + with self.assertRaisesRegexp( + ValueError, "not allowed in Network names"): + MyNetwork(name="slash/slash") + + @test_util.run_in_graph_and_eager_modes() + def testNoVariableScopeNames(self): + with self.assertRaisesRegexp( + ValueError, "VariableScopes are not valid Network names"): + with variable_scope.variable_scope("some_scope") as vs: + MyNetwork(name=vs) + + @test_util.run_in_graph_and_eager_modes() + def testVariableScopeNameCollision(self): + with variable_scope.variable_scope("abcd"): + pass + with self.assertRaisesRegexp( + ValueError, "or a variable_scope was created with this name"): + net = MyNetwork(name="abcd") + one = constant_op.constant([[1.]]) + net(one) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkVariablesDoNotInterfere(self): + core.Dense(1, use_bias=True) # Should not interfere with naming. + net1 = MyNetwork() + net2 = MyNetwork() + one = constant_op.constant([[1.]]) + net1(one) + net2(one) + # Layer names typically are globally unique rather than being unique within + # the scope of their first use. However, within a Network they must be named + # locally so that previous Layer consutrciton does not interfere with + # variable naming (e.g. add a Layer construction before the Network, + # suddenly your previously saved checkpoint is incompatible). + self.assertEqual("dense_1", net1.l1.name) + self.assertEqual("dense_1", net2.l1.name) + self.evaluate(net1.trainable_weights[0].assign([[1.]])) + self.evaluate(net2.trainable_weights[0].assign([[2.]])) + self.assertEqual(2., self.evaluate(net2.trainable_weights[0])) + self.assertEqual(1., self.evaluate(net1.trainable_weights[0])) + self.assertStartsWith(expected_start="my_network_1/dense_1/", + actual=net1.trainable_weights[0].name) + self.assertStartsWith(expected_start="my_network_2/dense_1/", + actual=net2.trainable_weights[0].name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableAnonymous(self): + + # The case where no explicit names are specified. We make up unique names, + # and these should match the variable names. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + actual=net.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + actual=net.first.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + actual=net.trainable_weights[1].name) + self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + actual=net.second.trainable_weights[0].name) + self.assertEqual("parent_network_1", net.name) + self.assertEqual("my_network_1", net.first.name) + self.assertEqual("my_network_2", net.second.name) + + net2 = ParentNetwork() + net2(one) + self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + actual=net2.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + actual=net2.first.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + actual=net2.trainable_weights[1].name) + self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + actual=net2.second.trainable_weights[0].name) + self.assertEqual("parent_network_2", net2.name) + self.assertEqual("my_network_1", net2.first.name) + self.assertEqual("my_network_2", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicit(self): + + # We have explicit network names and everything is globally unique. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__(name="unique_parent_name") + self.first = self.track_layer( + MyNetwork(name="first_unique_child_name")) + self.second = self.track_layer( + MyNetwork(name="second_unique_child_name")) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith( + expected_start="unique_parent_name/first_unique_child_name/dense", + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="unique_parent_name/second_unique_child_name/dense", + actual=net.trainable_weights[1].name) + self.assertEqual("unique_parent_name", net.name) + self.assertEqual("first_unique_child_name", net.first.name) + self.assertEqual("second_unique_child_name", net.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerNetworkNameInteractions(self): + + # Same base name as core.Dense; Networks and non-Network Layers with the + # same base name should use the same numbering system. + class Dense(network.Network): + + def __init__(self): + super(Dense, self).__init__() + self.first = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.first(x) + + class MixedLayerNetwork(network.Network): + + def __init__(self): + super(MixedLayerNetwork, self).__init__() + self.first = self.track_layer(core.Dense(1, use_bias=False)) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + self.third = self.track_layer(Dense()) + self.fourth = self.track_layer(core.Dense(1, use_bias=False)) + self.fifth = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.fifth(self.fourth(self.third(self.second(self.first(x))))) + + one = constant_op.constant([[1.]]) + net = MixedLayerNetwork() + net(one) + self.assertEqual("dense_1", net.first.name) + self.assertEqual("dense_2", net.second.name) + self.assertEqual("dense_3", net.third.name) + self.assertEqual("dense_4", net.fourth.name) + self.assertEqual("dense_5", net.fifth.name) + # Note that this is _not_ the default naming behavior for Layers. Layers + # which are added to Networks follow Network variable naming conventions + # (i.e. variable names = network name unless variable sharing). Nested + # Layers revert to Layer behavior. + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/", + actual=net.trainable_weights[0].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/", + actual=net.trainable_weights[1].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/", + actual=net.trainable_weights[2].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/", + actual=net.trainable_weights[3].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/", + actual=net.trainable_weights[4].name) + self.assertEqual("mixed_layer_network_1", net.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicitCollisions(self): + + # We have explicit network names and they are unique within the layer + # they're added to. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__(name="nonunique_name") + self.first = self.track_layer( + MyNetwork(name="nonunique_name")) + self.second = self.track_layer( + MyNetwork(name="second_unique_child_name")) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith( + expected_start="nonunique_name/nonunique_name/dense", + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="nonunique_name/second_unique_child_name/dense", + actual=net.trainable_weights[1].name) + self.assertEqual("nonunique_name", net.name) + self.assertEqual("nonunique_name", net.first.name) + self.assertEqual("second_unique_child_name", net.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicitWithAnonymousParent(self): + + # A parent network is instantiated multiple times with explicitly named + # children. We shouldn't throw any name errors. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer( + MyNetwork(name="first_unique_child_name")) + self.second = self.track_layer( + MyNetwork(name="second_unique_child_name")) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith( + expected_start="parent_network_1/first_unique_child_name/dense_1/", + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="parent_network_1/second_unique_child_name/dense_1/", + actual=net.trainable_weights[1].name) + self.assertEqual("parent_network_1", net.name) + self.assertEqual("first_unique_child_name", net.first.name) + self.assertEqual("second_unique_child_name", net.second.name) + + net2 = ParentNetwork() + net2(one) + self.assertStartsWith( + expected_start="parent_network_2/first_unique_child_name/dense", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="parent_network_2/second_unique_child_name/dense", + actual=net2.trainable_weights[1].name) + self.assertEqual("parent_network_2", net2.name) + self.assertEqual("first_unique_child_name", net2.first.name) + self.assertEqual("second_unique_child_name", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicitSameLayerCollisions(self): + + # We have explicit network names and they are _not_ unique within the layer + # they're added to. Error. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__(name="unique_parent_name") + self.first = self.track_layer(MyNetwork(name="nonunique_name")) + self.second = self.track_layer(MyNetwork(name="nonunique_name")) + + def call(self, x): + return self.second(self.first(x)) + + with self.assertRaisesRegexp(ValueError, "nonunique_name"): + ParentNetwork() + + @test_util.run_in_graph_and_eager_modes() + def testAnonymousVariableSharing(self): + + # Two "owned" Networks + class FirstParentNetwork(network.Network): + + def __init__(self): + super(FirstParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = FirstParentNetwork() + net(one) + + # One Network shared with FirstParentNetwork, one owned Network. Same name, + # but this is OK because only one is owned. This name collision is + # avoidable; we could have looked at the base_name of the non-owned Network + # and incremented our naming based on that. + class SecondParentNetwork(network.Network): + + def __init__(self): + super(SecondParentNetwork, self).__init__() + self.first = self.track_layer(net.first) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + net2 = SecondParentNetwork() + net2(one) + + self.assertStartsWith( + expected_start="first_parent_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="second_parent_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[1].name) + self.assertEqual("second_parent_network_1", net2.name) + self.assertTrue(net2.first is net.first) + self.assertEqual("my_network_1", net2.first.name) + self.assertEqual("my_network_1", net2.second.name) + + # No name collision; the owned Network is added first and has a different + # name than the shared Network. + class ThirdParentNetwork(network.Network): + + def __init__(self): + super(ThirdParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(net.second) + + def call(self, x): + return self.second(self.first(x)) + + net3 = ThirdParentNetwork() + net3(one) + + self.assertStartsWith( + expected_start="third_parent_network_1/my_network_1/dense", + actual=net3.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_parent_network_1/my_network_2/dense", + actual=net3.trainable_weights[1].name) + self.assertEqual("third_parent_network_1", net3.name) + self.assertTrue(net3.second is net.second) + self.assertEqual("my_network_1", net3.first.name) + self.assertEqual("my_network_2", net3.second.name) + + # "Unavoidable" same-name Layer. The owned name is added first (fixed), then + # a shared Network is added with the same name. + class FourthParentNetwork(network.Network): + + def __init__(self): + super(FourthParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(net.first) + + def call(self, x): + return self.second(self.first(x)) + + net4 = FourthParentNetwork() + net4(one) + + self.assertStartsWith( + expected_start="fourth_parent_network_1/my_network_1/dense_1/", + actual=net4.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_parent_network_1/my_network_1/dense_1/", + actual=net4.trainable_weights[1].name) + self.assertEqual("fourth_parent_network_1", net4.name) + self.assertTrue(net4.second is net.first) + self.assertEqual("my_network_1", net4.first.name) + self.assertEqual("my_network_1", net4.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testRecursiveLayerRenaming(self): + core.Dense(1) # Under default Layer naming, would change subsequent names. + + class NetworkWithLayerChildren(network.Network): + + def __init__(self): + super(NetworkWithLayerChildren, self).__init__() + self.first = self.track_layer(core.Dense(1, use_bias=False)) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer(NetworkWithLayerChildren()) + self.second = self.track_layer(NetworkWithLayerChildren()) + + def call(self, x): + return self.second(self.first(x)) + + net = ParentNetwork() + one = constant_op.constant([[1.]]) + net(one) + + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_1/" + "dense_1/"), + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_1/" + "dense_2/"), + actual=net.trainable_weights[1].name) + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_2/" + "dense_1/"), + actual=net.trainable_weights[2].name) + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_2/" + "dense_2/"), + actual=net.trainable_weights[3].name) + self.assertEqual("parent_network_1", net.name) + self.assertEqual("network_with_layer_children_1", net.first.name) + self.assertEqual("network_with_layer_children_2", net.second.name) + self.assertEqual("dense_1", net.first.first.name) + self.assertEqual("dense_2", net.first.second.name) + self.assertEqual("dense_1", net.second.first.name) + self.assertEqual("dense_2", net.second.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testCallInDifferentOrderThanConstruct(self): + shared_network = MyNetwork() + + class FirstNetwork(network.Network): + + def __init__(self): + super(FirstNetwork, self).__init__() + self.first = self.track_layer(shared_network) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + class SecondNetwork(network.Network): + + def __init__(self): + super(SecondNetwork, self).__init__() + self.first = self.track_layer(shared_network) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + net1 = FirstNetwork() + net2 = SecondNetwork() + + one = constant_op.constant([[1.]]) + net2(one) + net1(one) + + self.assertStartsWith( + expected_start="first_network_1/my_network_1/dense_1/", + actual=net1.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_network_1/my_network_2/dense_1/", + actual=net1.trainable_weights[1].name) + self.assertStartsWith( + expected_start="first_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="second_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[1].name) + self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) + self.assertEqual("first_network_1", net1.name) + self.assertEqual("my_network_1", net1.first.name) + self.assertEqual("my_network_2", net1.second.name) + self.assertTrue(net2.first is net1.first) + self.assertEqual("my_network_1", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerCallInDifferentOrderThanConstruct(self): + # Same idea as testCallInDifferentOrderThanConstruct, but this time with a + # non-Network Layer shared between two Networks rather than a + # Network. Naming should follow the same rules. + shared_layer = core.Dense(1, use_bias=False) + + class FirstNetwork(network.Network): + + def __init__(self): + super(FirstNetwork, self).__init__() + self.first = self.track_layer(shared_layer) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + class SecondNetwork(network.Network): + + def __init__(self): + super(SecondNetwork, self).__init__() + self.first = self.track_layer(shared_layer) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net1 = FirstNetwork() + net2 = SecondNetwork() + + one = constant_op.constant([[1.]]) + net2(one) + net1(one) + + self.assertStartsWith( + expected_start="first_network_1/dense_1/", + actual=net1.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_network_1/dense_2/", + actual=net1.trainable_weights[1].name) + self.assertStartsWith( + expected_start="first_network_1/dense_1/", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="second_network_1/dense_1/", + actual=net2.trainable_weights[1].name) + self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) + self.assertEqual("first_network_1", net1.name) + self.assertEqual("dense_1", net1.first.name) + self.assertEqual("dense_2", net1.second.name) + self.assertTrue(net2.first is net1.first) + self.assertEqual("dense_1", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerAlreadyBuilt(self): + one = constant_op.constant([[1.]]) + core.Dense(1, use_bias=False) # pre-built layers use global naming + one = constant_op.constant([[1.]]) + core.Dense(1, use_bias=False)(one) + shared_layer = core.Dense(1, use_bias=False) + shared_layer(one) + + class FirstNetwork(network.Network): + + def __init__(self): + super(FirstNetwork, self).__init__() + self.first = self.track_layer(shared_layer) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net = FirstNetwork() + net(one) + + self.assertStartsWith( + expected_start="dense_1/", # Pre-built layers have variable names which + # do not match their layer names. + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_network_1/dense_1/", + actual=net.trainable_weights[1].name) + self.assertTrue( + net.trainable_weights[0] is shared_layer.trainable_weights[0]) + self.assertEqual("first_network_1", net.name) + self.assertEqual("dense_3", net.first.name) + self.assertEqual("dense_1", net.second.name) + + +class SequentialTest(test.TestCase): + + def testTwoLayers(self): + # Create a sequential network with one layer. + net = network.Sequential([core.Dense(1, use_bias=False)]) + + # Set that layer's weights so it multiplies by 3 + l1 = net.get_layer(index=0) + net(constant_op.constant([[2.0]])) # Create l1's variables + self.assertEqual(1, len(l1.trainable_variables)) + l1.trainable_variables[0].assign([[3.0]]) + self.assertEqual(21.0, net(constant_op.constant([[7.0]])).numpy()) + + # Add a second layer to the network. + l2 = core.Dense(1, use_bias=False) + net.add(l2) + + # Set the second layer's weights so it multiplies by 11 + net(constant_op.constant([[2.0]])) # Create l2's variables + self.assertEqual(1, len(l2.trainable_variables)) + l2.trainable_variables[0].assign([[11.0]]) + self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy()) + + def testFunctions(self): + # Create a sequential network with one function. + net = network.Sequential([nn_ops.relu]) + two = constant_op.constant(2.0) + self.assertEqual(2.0, net(two).numpy()) + self.assertEqual(0.0, net(-two).numpy()) + # Add a second function. + net.add(math_ops.negative) + self.assertEqual(-2.0, net(two).numpy()) + + def testTrainingLayer(self): + net = network.Sequential([core.Dropout(0.99999)]) + two = constant_op.constant(2.0) + self.assertEqual(2.0, net(two).numpy()) + self.assertEqual(2.0, net(two, training=False).numpy()) + for _ in range(20): + with_dropout = net(two, training=True).numpy() + self.assertIn(with_dropout, [0.0, 2.0]) + if with_dropout == 0.0: + return + # Should only fail spuriously 1 in 10^100 runs. + self.fail("Didn't see dropout happen after 20 tries.") + + def testTrainingFunction(self): + # Output depends on value of "training". + def add_training(input_value, training=None): + if training is None: + return input_value + elif training: + return input_value + 1 + return input_value - 1 + + # Passing a "training" argument to double would cause an error. + def double(input_value): + return 2 * input_value + + net = network.Sequential([add_training, double]) + two = constant_op.constant(2) + self.assertEqual(4, net(two).numpy()) + self.assertEqual(2, net(two, training=False).numpy()) + self.assertEqual(6, net(two, training=True).numpy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index d289b83f537acc76fefa3343115a76c13ba7451b..e0a20d2485e831b1841991596b91429c6eaa2854 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -19,31 +19,34 @@ from __future__ import print_function import contextlib +from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import adam as _adam from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as _saver def _init_from_checkpoint(self, *args, **kwargs): """Overrides default init by loading value from checkpoint.""" - self.old_init(*args, **kwargs) # pylint: disable=protected-access - if self._shared_name not in self.ckpt_var_cache: + self._old_init(*args, **kwargs) + ckpt_name = self._map_func(self._shared_name) + if ckpt_name not in self._ckpt_var_cache: raise errors.NotFoundError(None, None, - "%s not found in checkpoint" % self._shared_name) + "%s not found in checkpoint" % ckpt_name) - val = self.ckpt_var_cache[self._shared_name] + val = self._ckpt_var_cache.get(ckpt_name, None) if val is not None: - self.assign(self.ckpt_var_cache[self._shared_name]) + self.assign(val) # Avoid assigning for the second time. - self.ckpt_var_cache[self._shared_name] = None + self._ckpt_var_cache[ckpt_name] = None # pylint: enable=protected-access @contextlib.contextmanager -def restore_variables_on_create(save_path): +def restore_variables_on_create(save_path, map_func=None): """ContextManager that restores variables on creation. When save_path is None (e.g. No checkpoint), does nothing. @@ -58,26 +61,45 @@ def restore_variables_on_create(save_path): Args: save_path: The checkpoint file prefix. + map_func: A function that given the variable name as argument + and returns a variable name in checkpoint for restore. If + None, use the variable with the same name in checkpoint to restore. + It's an error that the mapped variable name doesn't exist in + checkpoint. Yields: Nothing. Raises: NotFoundError: If the variable is not found in checkpoint. + ValueError: If not used in eager mode or map_func is not callable. """ + if context.in_graph_mode(): + raise ValueError( + "Currently, restore_variables_on_create can only be used with " + "eager execution enabled.") if save_path: + if map_func is None: + map_func_wrapper = lambda self, x: x + else: + if not callable(map_func): + raise ValueError("map_func must be callaled.") + map_func_wrapper = lambda self, x: map_func(x) + ckpt_var_cache = dict() reader = checkpoint_utils.load_checkpoint(save_path) for k, _ in checkpoint_utils.list_variables(save_path): ckpt_var_cache[k] = reader.get_tensor(k) - old_init = getattr( - resource_variable_ops.ResourceVariable, "_init_from_args", None) + old_init = getattr(resource_variable_ops.ResourceVariable, + "_init_from_args", None) assert old_init, "ResourceVariable misses _init_from_args method." setattr(resource_variable_ops.ResourceVariable, "_init_from_args", _init_from_checkpoint) - setattr(resource_variable_ops.ResourceVariable, "old_init", old_init) - setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", + setattr(resource_variable_ops.ResourceVariable, "_old_init", old_init) + setattr(resource_variable_ops.ResourceVariable, "_map_func", + map_func_wrapper) + setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", ckpt_var_cache) try: yield @@ -87,43 +109,82 @@ def restore_variables_on_create(save_path): if save_path: setattr(resource_variable_ops.ResourceVariable, "_init_from_args", old_init) - setattr(resource_variable_ops.ResourceVariable, "old_init", None) - setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None) + setattr(resource_variable_ops.ResourceVariable, "_old_init", None) + setattr(resource_variable_ops.ResourceVariable, "_map_func", None) + setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", None) class Saver(object): - """A simple tf.train.Saver adapter for eager mode. - - save and restore API are similar to the tf.train.Saver, except that - session is not needed. - - Args: - var_list: A list of variables. + """A tf.train.Saver adapter for use when eager execution is enabled. """ def __init__(self, var_list): + """A tf.train.Saver adapter for use when eager execution is enabled. + + The API, and on-disk format, mimic tf.train.Saver except that no + Session is needed. + + Args: + var_list: The list of variables that will be saved and restored. Either a + list of `tfe.Variable` objects, or a dictionary mapping names to + `tfe.Variable` objects. + + Raises: + RuntimeError: if invoked when eager execution has not been enabled. + """ + if context.in_graph_mode(): + raise RuntimeError("tfe.Saver can only be used when eager " + "execution is enabled. Use tf.train.Saver when " + "building graphs.") self._saver = _saver.Saver(var_list=var_list) - def save(self, save_path, global_step=None): + def save(self, file_prefix, global_step=None): """Saves variables. Args: - save_path: See save method in tf.train.Saver. - global_step: See save method in tf.train.Saver. + file_prefix: Path prefix of files created for the checkpoint. + global_step: If provided the global step number is appended to file_prefix + to create the checkpoint filename. The optional argument can be a + Tensor, a Variable, or an integer. Returns: - See save method in tf.train.Saver. + A string: prefix of filenames created for the checkpoint. This may be + an extension of file_prefix that is suitable to pass as an argument + to a subsequent call to `restore()`. """ with ops.device("/device:CPU:0"): - return self._saver.save(None, save_path, write_meta_graph=False, - global_step=global_step) + return self._saver.save( + None, file_prefix, write_meta_graph=False, global_step=global_step) - def restore(self, save_path): + def restore(self, file_prefix): """Restores previously saved variables. Args: - save_path: See restore method in tf.train.Saver. + file_prefix: Path prefix where parameters were previously saved. + Typically obtained from a previous `save()` call, or from + @{tf.train.latest_checkpoint}. """ with ops.device("/device:CPU:0"): - self._saver.restore(None, save_path) + self._saver.restore(None, file_prefix) + + +def get_optimizer_variables(optimizer): + """Returns a list of variables for the given `tf.train.Optimizer`. + Args: + optimizer: An instance of `tf.train.Optimizer` which has created variables + (typically after a call to `Optimizer.minimize`). + Returns: + A list of variables which have been created by the `Optimizer`. Currently + returns all variables even if they were not created in the default graph, + but this behavior may change. + """ + variables = [] + # pylint: disable=protected-access + for _, variable_dict in optimizer._slots.items(): + for _, slot_for_variable in variable_dict.items(): + variables.append(slot_for_variable) + if isinstance(optimizer, _adam.AdamOptimizer): + variables.append(optimizer._beta1_power) + variables.append(optimizer._beta2_power) + return variables diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index cdec50ebd787f0cfffedee6391155e97f77c63d8..abc7e3690c76c4446bce6b945325f1ca15ef1c8b 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -21,11 +21,19 @@ import os from tensorflow.contrib.eager.python import saver as _saver from tensorflow.python.eager import context +from tensorflow.python.eager import graph_callable +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.platform import test +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import adam +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import momentum +from tensorflow.python.training import rmsprop class SaverTest(test.TestCase): @@ -34,7 +42,7 @@ class SaverTest(test.TestCase): return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0' def testBasics(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') def model(): return array_ops.constant(2.0) * v1 @@ -50,8 +58,76 @@ class SaverTest(test.TestCase): saver.restore(ckpt_prefix) self.assertEqual(v1.read_value().numpy(), 1.0) + def testSameNameNoClobbering(self): + with ops.device(self._dev()): + # Note that this test purposefully uses Graphs rather than + # IsolateTest. Users are more likely to accidentally create the same + # variable name this way. + first_graph = ops.Graph() + with first_graph.as_default(): + v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1_first_graph, v1_second_graph]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'v1'): + saver.save(ckpt_prefix) + + def testDifferentGraphError(self): + with ops.device(self._dev()): + with ops.Graph().as_default(): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + saver = _saver.Saver([v1]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'Graph'): + saver.save(ckpt_prefix) + + def testSameObjectOK(self): + with ops.device(self._dev()): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + # While different objects with the same shared_name are not good, passing + # in the same object multiple times is fine. + saver = _saver.Saver([v1, v1]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + saver.save(ckpt_prefix) + + def testSaveByDict(self): + with ops.device(self._dev()): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + v2 = resource_variable_ops.ResourceVariable(1.0, name='v2') + def model(): + return array_ops.constant(2.0) * v1 * v2 + + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + + # Save the variables under different names. + _ = model() + saver = _saver.Saver({'ckpt/v1': v1, 'ckpt/v2': v2}) + saver.save(ckpt_prefix) + v1.assign(2.0) + v2.assign(2.0) + self.assertEqual(v1.read_value().numpy(), 2.0) + self.assertEqual(v2.read_value().numpy(), 2.0) + # Can still restore it. + saver.restore(ckpt_prefix) + self.assertEqual(v1.read_value().numpy(), 1.0) + self.assertEqual(v1.read_value().numpy(), 1.0) + # However, cannot restore it with default name. + with self.assertRaisesOpError('not found in checkpoint'): + saver = _saver.Saver([v1, v2]).restore(ckpt_prefix) + + # Can specify which variable in ckpt to restore to which variable. + def map_func(x): + return {'v3': 'ckpt/v1', 'v4': 'ckpt/v2'}.get(x, x) + with _saver.restore_variables_on_create(ckpt_prefix, map_func): + v3 = resource_variable_ops.ResourceVariable(2.0, name='v3') + v4 = resource_variable_ops.ResourceVariable(2.0, name='v4') + self.assertEqual(v3.read_value().numpy(), 1.0) + self.assertEqual(v4.read_value().numpy(), 1.0) + def testRestoreOnCreate(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): def model(init_val): v1 = resource_variable_ops.ResourceVariable(init_val, name='v1') return array_ops.constant(1.0) * v1, v1 @@ -67,12 +143,9 @@ class SaverTest(test.TestCase): # Value is from checkpoint, but not from argument. ret, _ = model(2.0) self.assertEqual(ret.numpy(), 1.0) - # Create it a second time won't re-assign the checkpoint value. - v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1') - self.assertEqual(v1_2.read_value().numpy(), 3.0) def testRestoreNotFound(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): def model(v): return array_ops.constant(1.0) * v @@ -87,6 +160,90 @@ class SaverTest(test.TestCase): with _saver.restore_variables_on_create(ckpt_prefix): _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2')) + def testSaveRestoreGraphCallable(self): + with ops.device(self._dev()): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + # Default 2 + 0 = 2 + self.assertEqual( + 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # Save the variable value 0. + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + _saver.Saver(model.variables).save(ckpt_prefix) + + # update variable to 1, so that 2 + 1 = 3 + model.variables[0].assign(1.) + self.assertEqual( + 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # load the variable value 0, so that 2 + 0 = 2 + _saver.Saver(model.variables).restore(ckpt_prefix) + self.assertEqual( + 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # update checkpoint variable to 1 and memory value to 2. + model.variables[0].assign(1.) + _saver.Saver(model.variables).save(ckpt_prefix) + model.variables[0].assign(2.) + self.assertEqual( + 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # reset the graph and reload on create, so that 1 + 2 = 3 + with ops.Graph().as_default(): + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + +class GetOptimizerTests(test.TestCase): + + def _optimizer_test_template(self, optimizer): + """Checks save and restore. Returns the optimizer variables.""" + v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v') + loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2 + optimizer.minimize(loss_fn) + optimizer_variables = _saver.get_optimizer_variables(optimizer) + saver = _saver.Saver(optimizer_variables + [v]) + checkpoint_path = saver.save(self.get_temp_dir()) + optimizer.minimize(loss_fn) + after_first_minimize = v.numpy() + # After we restore, the next step should be exactly the same as the one we + # just did. + saver.restore(checkpoint_path) + optimizer.minimize(loss_fn) + self.assertAllEqual(after_first_minimize, v.numpy()) + return optimizer_variables + + def testAdam(self): + optimizer = adam.AdamOptimizer(0.1) + self._optimizer_test_template(optimizer) + + def testGradientDescent(self): + optimizer = gradient_descent.GradientDescentOptimizer(0.02) + self.assertEqual(0, len(self._optimizer_test_template(optimizer))) + + def testMomentum(self): + optimizer = momentum.MomentumOptimizer( + learning_rate=0.03, + momentum=0.5) + self._optimizer_test_template(optimizer) + + def testRMSProp(self): + optimizer = rmsprop.RMSPropOptimizer(0.01) + self._optimizer_test_template(optimizer) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/summary_writer.py b/tensorflow/contrib/eager/python/summary_writer.py index 39993558e33d9f88c9f642db2273fb81fd7be9e9..5a698b92c6a04992692d21ead25e91cd6c5d7e78 100644 --- a/tensorflow/contrib/eager/python/summary_writer.py +++ b/tensorflow/contrib/eager/python/summary_writer.py @@ -32,9 +32,9 @@ from tensorflow.python.ops import summary_op_util from tensorflow.python.ops import variable_scope -def _maybe_as_cpu_tensor(v): +def _maybe_cpu(v): if isinstance(v, (ops.EagerTensor, ops.Tensor)): - return v.as_cpu_tensor() + return v.cpu() else: return v @@ -161,9 +161,9 @@ class SummaryWriter(object): gen_summary_ops.write_summary( self._resource, self._update_global_step_tensor(), - _maybe_as_cpu_tensor(tensor), + _maybe_cpu(tensor), tag, - _maybe_as_cpu_tensor(metadata), + _maybe_cpu(metadata), name=scope) def scalar(self, name, tensor, family=None): @@ -185,7 +185,7 @@ class SummaryWriter(object): name, family, values=[tensor]) as (tag, scope): gen_summary_ops.write_scalar_summary( self._resource, self._update_global_step_tensor(), - tag, _maybe_as_cpu_tensor(tensor), name=scope) + tag, _maybe_cpu(tensor), name=scope) def histogram(self, name, tensor, family=None): """Write a histogram summary. @@ -203,7 +203,7 @@ class SummaryWriter(object): name, family, values=[tensor]) as (tag, scope): gen_summary_ops.write_histogram_summary( self._resource, self._update_global_step_tensor(), - tag, _maybe_as_cpu_tensor(tensor), name=scope) + tag, _maybe_cpu(tensor), name=scope) def image(self, name, tensor, bad_color=None, max_images=3, family=None): """Write an image summary.""" @@ -214,7 +214,7 @@ class SummaryWriter(object): name, family, values=[tensor]) as (tag, scope): gen_summary_ops.write_image_summary( self._resource, self._update_global_step_tensor(), - tag, _maybe_as_cpu_tensor(tensor), bad_color_, max_images, + tag, _maybe_cpu(tensor), bad_color_, max_images, name=scope) def audio(self, name, tensor, sample_rate, max_outputs, family=None): @@ -238,7 +238,7 @@ class SummaryWriter(object): gen_summary_ops.write_audio_summary( self._resource, self._update_global_step_tensor(), tag, - _maybe_as_cpu_tensor(tensor), - sample_rate=_maybe_as_cpu_tensor(sample_rate), + _maybe_cpu(tensor), + sample_rate=_maybe_cpu(sample_rate), max_outputs=max_outputs, name=scope) diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 249aaebea24623421c39b4f41bf486ced6579b65..ab31893cd3e3459395be3360a0fc963684b744e6 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -18,6 +18,8 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. To use, at program startup, call `tfe.enable_eager_execution()`. +@@metrics + @@list_devices @@num_gpus @@ -26,6 +28,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@implicit_value_and_gradients @@gradients_function @@value_and_gradients_function +@@GradientTape @@enable_tracing @@flush_trace @@ -43,10 +46,22 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@seterr @@Iterator +@@Network @@Saver @@SummaryWriter @@restore_variables_on_create @@Variable +@@get_optimizer_variables + +@@in_eager_mode +@@in_graph_mode + +@@IsolateTest +@@run_test_in_graph_and_eager_modes + +@@DEVICE_PLACEMENT_EXPLICIT +@@DEVICE_PLACEMENT_WARN +@@DEVICE_PLACEMENT_SILENT """ from __future__ import absolute_import @@ -56,31 +71,42 @@ from __future__ import print_function # pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import # +from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator -from tensorflow.contrib.eager.python.saver import Saver +from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create +from tensorflow.contrib.eager.python.saver import Saver from tensorflow.contrib.eager.python.summary_writer import SummaryWriter -from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.eager import backprop -from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager import function -from tensorflow.python.eager.context import enable_eager_execution +from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT +from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN +from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT +from tensorflow.python.eager.context import in_eager_mode +from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.context import run from tensorflow.python.eager.core import enable_tracing +from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.ops import enable_eager_execution +from tensorflow.python.framework.ops import eager_run as run +from tensorflow.python.framework.test_util import IsolateTest +from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable +from tensorflow.python.util.all_util import remove_undocumented defun = function.defun implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad gradients_function = backprop.gradients_function value_and_gradients_function = backprop.val_and_grad_function +GradientTape = backprop.GradientTape # pylint: disable=invalid-name remove_undocumented(__name__) diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 3d57a98a2ee068281b0934484994e113989e75ce..d8a38923a3f1e2426c3d96c2bb52380dc0d2e0c1 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -24,6 +24,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import numerics +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -39,6 +41,11 @@ class TFETest(test_util.TensorFlowTestCase): r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) + def testVariableError(self): + with self.assertRaisesRegexp( + RuntimeError, r'Variable not supported in Eager mode'): + variables.Variable(initial_value=1.0) + def testGradients(self): def square(x): @@ -75,7 +82,7 @@ class TFETest(test_util.TensorFlowTestCase): self.skipTest('No GPUs available') # tf.Tensor.as_gpu_device() moves a tensor to GPU. - x = constant_op.constant([[1., 2.], [3., 4.]]).as_gpu_tensor() + x = constant_op.constant([[1., 2.], [3., 4.]]).gpu() # Alternatively, tf.device() as a context manager places tensors and # operations. with ops.device('gpu:0'): @@ -85,7 +92,7 @@ class TFETest(test_util.TensorFlowTestCase): reduction_indices = range(x.shape.ndims) m = math_ops.reduce_mean(x, reduction_indices) # m is on GPU, bring it back to CPU and compare. - self.assertEqual(3.5, m.as_cpu_tensor().numpy()) + self.assertEqual(3.5, m.cpu().numpy()) def testListDevices(self): # Expect at least one device. @@ -95,12 +102,11 @@ class TFETest(test_util.TensorFlowTestCase): devices = tfe.list_devices() self.assertEqual(len(devices) - 1, tfe.num_gpus()) - def testCallingEnableEagerExecutionMoreThanOnce(self): - # Note that eager.test.main() has already invoked enable_eager_exceution(). + def testAddCheckNumericsOpsRaisesError(self): with self.assertRaisesRegexp( - ValueError, r'Do not call tfe\.%s more than once in the same process' % - tfe.enable_eager_execution.__name__): - tfe.enable_eager_execution() + RuntimeError, + r'add_check_numerics_ops\(\) is not compatible with eager execution'): + numerics.add_check_numerics_ops() if __name__ == '__main__': diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 596f68844b3628d7101fe16e095db7b5160d5baf..79b166ac8882e0ad519ca222931e45b21223d77b 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -50,7 +50,10 @@ py_test( size = "small", srcs = ["python/estimator/dnn_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "notsan", + ], deps = [ ":dnn", ":head", @@ -89,7 +92,7 @@ py_library( py_test( name = "extenders_test", - size = "small", + size = "medium", srcs = ["python/estimator/extenders_test.py"], srcs_version = "PY2AND3", deps = [ @@ -131,7 +134,9 @@ py_library( "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/estimator:util", "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", ], ) @@ -146,9 +151,11 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", @@ -183,7 +190,8 @@ py_test( deps = [ ":logit_fns", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:session", "//tensorflow/python/estimator:model_fn", ], ) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 3e5eb3390f62141a82b51011d278d995b488b5e7..29c3c7358534f6e8ebbd31cbfcd7e34086d9b506 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib -from tensorflow.python.util import tf_inspect + _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @@ -317,9 +317,6 @@ class _TransformGradients(optimizer_lib.Optimizer): def _verify_metric_fn_args(metric_fn): args = set(estimator_util.fn_args(metric_fn)) - if tf_inspect.ismethod(metric_fn): - if 'self' in args: - args.remove('self') invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 9b14622ff6436efcf66dae311f773c8375b2cafa..189f098005b8926bfb30b723cc989cb854a5d77e 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys @@ -33,8 +34,11 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses +from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + def multi_class_head(n_classes, weight_column=None, @@ -59,7 +63,7 @@ def multi_class_head(n_classes, `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for multi class classification. @@ -98,7 +102,7 @@ def binary_classification_head( `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for binary classification. @@ -129,7 +133,7 @@ def regression_head(weight_column=None, of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for linear regression. @@ -144,6 +148,7 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, + loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -155,6 +160,12 @@ def multi_label_head(n_classes, multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer `SparseTensor` of class indices. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape + `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the + input labels before passing them to `loss_fn`. + Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use `binary_classification_head`). @@ -171,8 +182,9 @@ def multi_label_head(n_classes, [0, n_classes) or multi-hot Tensor. If given, labels must be SparseTensor string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for multi-label classification. @@ -198,9 +210,11 @@ def multi_label_head(n_classes, raise ValueError( 'Length of label_vocabulary must be n_classes ({}). ' 'Given: {}'.format(n_classes, len(label_vocabulary))) + if loss_fn: + _validate_loss_fn_args(loss_fn) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, - label_vocabulary=label_vocabulary, name=name) + label_vocabulary=label_vocabulary, loss_fn=loss_fn, name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -211,11 +225,13 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, + loss_fn=None, name=None): self._n_classes = n_classes self._weight_column = weight_column self._thresholds = thresholds self._label_vocabulary = label_vocabulary + self._loss_fn = loss_fn self._name = name @property @@ -227,6 +243,12 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access return self._n_classes def _process_labels(self, labels): + if labels is None: + raise ValueError( + 'You must provide a labels Tensor. Given: None. ' + 'Suggested troubleshooting steps: Check that your data contain ' + 'your label feature. Check that your input_fn properly parses and ' + 'returns labels.') if isinstance(labels, sparse_tensor.SparseTensor): if labels.dtype == dtypes.string: label_ids_values = lookup_ops.index_table_from_tensor( @@ -254,11 +276,19 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access def create_loss(self, features, mode, logits, labels): """See `Head`.""" - del mode, features # Unused for this head. + del mode # Unused for this head. processed_labels = self._process_labels(labels) - unweighted_loss = losses.sigmoid_cross_entropy( - multi_class_labels=processed_labels, logits=logits, - reduction=losses.Reduction.NONE) + if self._loss_fn: + unweighted_loss = _call_loss_fn( + loss_fn=self._loss_fn, labels=processed_labels, logits=logits, + features=features) + else: + unweighted_loss = losses.sigmoid_cross_entropy( + multi_class_labels=processed_labels, logits=logits, + reduction=losses.Reduction.NONE) + # Averages loss over classes. + unweighted_loss = math_ops.reduce_mean( + unweighted_loss, axis=-1, keep_dims=True) return head_lib.LossAndLabels( unweighted_loss=unweighted_loss, processed_labels=processed_labels) @@ -266,7 +296,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" - with ops.name_scope('head'): + with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access # Predict. @@ -278,22 +308,25 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access pred_keys.PROBABILITIES: probabilities, } if mode == model_fn.ModeKeys.PREDICT: + classifier_output = head_lib._classification_output( # pylint:disable=protected-access + scores=probabilities, n_classes=self._n_classes, + label_vocabulary=self._label_vocabulary) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ - '': export_output.ClassificationOutput(scores=probabilities) + _DEFAULT_SERVING_KEY: classifier_output, + head_lib._CLASSIFY_SERVING_KEY: classifier_output, # pylint:disable=protected-access + head_lib._PREDICT_SERVING_KEY: ( # pylint:disable=protected-access + export_output.PredictOutput(predictions)) }) # Eval. unweighted_loss, processed_labels = self.create_loss( features=features, mode=mode, logits=logits, labels=labels) - # Averages loss over classes. - per_example_loss = math_ops.reduce_mean( - unweighted_loss, axis=-1, keep_dims=True) weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access training_loss = losses.compute_weighted_loss( - per_example_loss, weights=weights, reduction=losses.Reduction.SUM) + unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) if mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, @@ -303,7 +336,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access labels=processed_labels, probabilities=probabilities, weights=weights, - per_example_loss=per_example_loss)) + unweighted_loss=unweighted_loss)) # Train. if train_op_fn is None: @@ -324,16 +357,16 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access loss=training_loss, train_op=train_op_fn(training_loss)) - def _eval_metric_ops(self, labels, probabilities, weights, per_example_loss): + def _eval_metric_ops(self, labels, probabilities, weights, unweighted_loss): """Returns a dict of metrics for eval_metric_ops.""" with ops.name_scope( - None, 'metrics', [labels, probabilities, weights, per_example_loss]): + None, 'metrics', [labels, probabilities, weights, unweighted_loss]): keys = metric_keys.MetricKeys metric_ops = { # Estimator already adds a metric for loss. head_lib._summary_key(self._name, keys.LOSS_MEAN): # pylint:disable=protected-access metrics_lib.mean( - per_example_loss, weights=weights, name=keys.LOSS_MEAN), + unweighted_loss, weights=weights, name=keys.LOSS_MEAN), head_lib._summary_key(self._name, keys.AUC): # pylint:disable=protected-access metrics_lib.auc( labels=labels, predictions=probabilities, weights=weights, @@ -371,3 +404,53 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access threshold=threshold, name=recall_key)) return metric_ops + + +def _validate_loss_fn_args(loss_fn): + """Validates loss_fn arguments. + + Required arguments: labels, logits. + Optional arguments: features. + + Args: + loss_fn: The loss function. + Raises: + ValueError: If the signature is unexpected. + """ + loss_fn_args = util.fn_args(loss_fn) + for required_arg in ['labels', 'logits']: + if required_arg not in loss_fn_args: + raise ValueError( + 'loss_fn must contain argument: {}. ' + 'Given arguments: {}'.format(required_arg, loss_fn_args)) + invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features'])) + if invalid_args: + raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args)) + + +def _call_loss_fn(loss_fn, labels, logits, features): + """Calls loss_fn and checks the returned shape. + + Args: + loss_fn: The loss function. + labels: Processed labels Tensor. + logits: Logits Tensor of shape [batch_size, logits_dimension]. + features: Features dict. + Returns: + Loss Tensor with shape [batch_size, 1]. + """ + loss_fn_args = util.fn_args(loss_fn) + kwargs = {} + if 'features' in loss_fn_args: + kwargs['features'] = features + unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs) + batch_size = array_ops.shape(logits)[0] + loss_shape = array_ops.shape(unweighted_loss) + check_shape_op = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])), + data=[ + 'loss_fn must return Tensor of shape [batch_size, 1]. Given: ', + loss_shape]) + with ops.control_dependencies([check_shape_op]): + return array_ops.identity(unweighted_loss) + diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 9dd9e433277304b320ac17d6478383531f114806..db7d96d508649f93c23b55504088551747f15a26 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -32,6 +32,8 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -79,9 +81,13 @@ def _sigmoid(logits): def _sigmoid_cross_entropy(labels, logits): + """Returns sigmoid cross entropy averaged over classes.""" sigmoid_logits = _sigmoid(logits) - return (-labels * np.log(sigmoid_logits) - -(1 - labels) * np.log(1 - sigmoid_logits)) + unreduced_result = ( + -labels * np.log(sigmoid_logits) + -(1 - labels) * np.log(1 - sigmoid_logits)) + # Mean over classes + return np.mean(unreduced_result, axis=-1, keepdims=True) class MultiLabelHead(test.TestCase): @@ -126,6 +132,37 @@ class MultiLabelHead(test.TestCase): r'Length of label_vocabulary must be n_classes \(3\). Given: 2'): head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar']) + def test_loss_fn_arg_labels_missing(self): + def _loss_fn(logits): + del logits # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: labels\. ' + r'Given arguments: \(\'logits\',\)'): + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_logits_missing(self): + def _loss_fn(labels): + del labels # unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: logits\. ' + r'Given arguments: \(\'labels\',\)'): + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_features_ok(self): + def _loss_fn(labels, logits, features): + del labels, logits, features # Unused + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_invalid(self): + def _loss_fn(labels, logits, name=None): + del labels, logits, name # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn has unexpected args: \[\'name\'\]'): + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + def test_name(self): head = head_lib.multi_label_head(n_classes=4, name='foo') self.assertEqual('foo', head.name) @@ -138,6 +175,7 @@ class MultiLabelHead(test.TestCase): logits = np.array( [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32) expected_probabilities = _sigmoid(logits) + expected_export_classes = [[b'0', b'1', b'2', b'3']] * 2 spec = head.create_estimator_spec( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -145,7 +183,8 @@ class MultiLabelHead(test.TestCase): logits=logits) self.assertItemsEqual( - ('', _DEFAULT_SERVING_KEY), spec.export_outputs.keys()) + (_DEFAULT_SERVING_KEY, 'predict', 'classification'), + spec.export_outputs.keys()) # Assert predictions and export_outputs. with self.test_session() as sess: @@ -161,6 +200,29 @@ class MultiLabelHead(test.TestCase): self.assertAllClose( expected_probabilities, sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) + self.assertAllEqual( + expected_export_classes, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes)) + + def test_predict_with_label_vocabulary(self): + n_classes = 4 + head = head_lib.multi_label_head( + n_classes, label_vocabulary=['foo', 'bar', 'foobar', 'barfoo']) + + logits = np.array( + [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32) + expected_export_classes = [[b'foo', b'bar', b'foobar', b'barfoo']] * 2 + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertAllEqual( + expected_export_classes, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes)) def test_weight_should_not_impact_prediction(self): n_classes = 4 @@ -225,7 +287,7 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits expected_unweighted_loss = np.array( - [[10., 10.], [15., 0.]], dtype=np.float32) + [[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32) actual_unweighted_loss, _ = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, @@ -261,6 +323,66 @@ class MultiLabelHead(test.TestCase): actual_unweighted_loss.eval( {labels_placeholder: np.array([1, 1], dtype=np.int64)}) + def test_eval_create_loss_loss_fn(self): + """Tests head.create_loss for eval mode and custom loss_fn.""" + loss = np.array([[1.], [2.]], dtype=np.float32) + logits_input = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels_input = np.array([[1, 0], [1, 1]], dtype=np.int64) + def _loss_fn(labels, logits): + check_labels = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(labels, labels_input)), + data=[labels]) + check_logits = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(logits, logits_input)), + data=[logits]) + with ops.control_dependencies([check_labels, check_logits]): + return constant_op.constant(loss) + head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) + + actual_unweighted_loss, _ = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits_input, + labels=labels_input) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose(loss, actual_unweighted_loss.eval()) + + def test_eval_create_loss_loss_fn_wrong_shape(self): + """Tests custom loss_fn that returns Tensor of unexpected shape.""" + loss = np.array([1., 2.], dtype=np.float32) + def _loss_fn(labels, logits): + del labels, logits # Unused + return constant_op.constant(loss) + head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) + + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + actual_unweighted_loss, _ = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'loss_fn must return Tensor of shape \[batch_size, 1\]\. ' + r'Given: \] \[2\]'): + actual_unweighted_loss.eval() + + def test_eval_labels_none(self): + """Tests that error is raised when labels is None.""" + head = head_lib.multi_label_head(n_classes=2) + + with self.assertRaisesRegexp( + ValueError, r'You must provide a labels Tensor\. Given: None\.'): + head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=None) + def _test_eval(self, head, logits, labels, expected_loss, expected_metrics): spec = head.create_estimator_spec( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -298,10 +420,8 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes - ) + # Sum over examples. + expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. @@ -330,10 +450,9 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. + # Sum over examples. expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) / - n_classes + np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) ) keys = metric_keys.MetricKeys expected_metrics = { @@ -364,10 +483,9 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. + # Sum over examples. expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) / - n_classes + np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) ) keys = metric_keys.MetricKeys expected_metrics = { @@ -394,9 +512,9 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. + # Sum over examples. expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes + np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) ) keys = metric_keys.MetricKeys @@ -493,7 +611,7 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits expected_unweighted_loss = np.array( - [[10., 10.], [15., 0.]], dtype=np.float32) + [[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32) actual_unweighted_loss, _ = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.TRAIN, @@ -504,6 +622,22 @@ class MultiLabelHead(test.TestCase): self.assertAllClose( expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) + def test_train_labels_none(self): + """Tests that error is raised when labels is None.""" + head = head_lib.multi_label_head(n_classes=2) + def _no_op_train_fn(loss): + del loss + return control_flow_ops.no_op() + + with self.assertRaisesRegexp( + ValueError, r'You must provide a labels Tensor\. Given: None\.'): + head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=None, + train_op_fn=_no_op_train_fn) + def _test_train(self, head, logits, labels, expected_loss): expected_train_result = 'my_train_op' def _train_op_fn(loss): diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index e6340424f741cd0278dbdef41dd4395e98f23246..64b2a9dee83801b5d6d852a3485fc0cc81417ff0 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -236,7 +236,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access for head, spec in zip(self._heads, all_estimator_spec): head_name = head.name for k, v in six.iteritems(spec.export_outputs): - key = '%s/%s' % (k, head_name) if k else head_name + if k == _DEFAULT_SERVING_KEY: + key = head_name + else: + key = '%s/%s' % (k, head_name) export_outputs[key] = v for k, v in six.iteritems(spec.predictions): predictions[(head_name, k)] = v diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index e86cb2b96fe1c10352337367616a0ea2ff9132cc..48027035cecffc3ce8aacf8ae917f5eb9e9b2473 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -126,8 +126,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, _DEFAULT_SERVING_KEY + '/head1', 'head1', - _DEFAULT_SERVING_KEY + '/head2', 'head2'), + (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', + 'head2', 'classification/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index c468c544d372e8bfd6adfa49a58e9bf6c5ef0a8b..44095bd00a7a098a8a89ba4d25c68a2484c00a6e 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -8,6 +8,7 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") @@ -23,6 +24,7 @@ tf_custom_op_py_library( "python/ops/factorization_ops.py", "python/ops/gmm.py", "python/ops/gmm_ops.py", + "python/ops/kmeans.py", "python/ops/wals.py", ], dso = [ @@ -199,6 +201,29 @@ tf_py_test( ) # Estimators tests +py_test( + name = "kmeans_test", + size = "medium", + srcs = ["python/ops/kmeans_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/67512932 + deps = [ + ":factorization_py", + ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:random_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "wals_test", size = "large", @@ -221,7 +246,6 @@ tf_py_test( "manual", "noasan", # times out b/63678675 "nomsan", - "notsan", ], ) diff --git a/tensorflow/contrib/factorization/__init__.py b/tensorflow/contrib/factorization/__init__.py index 486c2ea9336d19fb7273d02502f9865adc6aefed..6112c9d8300fe219c8e172a5b70e4ce4cad04eb6 100644 --- a/tensorflow/contrib/factorization/__init__.py +++ b/tensorflow/contrib/factorization/__init__.py @@ -23,22 +23,24 @@ from tensorflow.contrib.factorization.python.ops.clustering_ops import * from tensorflow.contrib.factorization.python.ops.factorization_ops import * from tensorflow.contrib.factorization.python.ops.gmm import * from tensorflow.contrib.factorization.python.ops.gmm_ops import * +from tensorflow.contrib.factorization.python.ops.kmeans import * from tensorflow.contrib.factorization.python.ops.wals import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'KMeans', 'COSINE_DISTANCE', - 'KMEANS_PLUS_PLUS_INIT', - 'RANDOM_INIT', - 'SQUARED_EUCLIDEAN_DISTANCE', - 'WALSModel', 'GMM', 'gmm', 'GmmAlgorithm', + 'KMeans', + 'KMEANS_PLUS_PLUS_INIT', + 'KMeansClustering', + 'RANDOM_INIT', + 'SQUARED_EUCLIDEAN_DISTANCE', 'WALSMatrixFactorization', + 'WALSModel', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/factorization/examples/mnist.py b/tensorflow/contrib/factorization/examples/mnist.py index 9eefbccd4dadf37bbaf728b44007d642fc1439e4..06a62db0049d5c698b2378d207e3657fe66acdbd 100644 --- a/tensorflow/contrib/factorization/examples/mnist.py +++ b/tensorflow/contrib/factorization/examples/mnist.py @@ -142,7 +142,7 @@ def inference(inp, num_clusters, hidden1_units, hidden2_units): # initial_clusters=tf.contrib.factorization.KMEANS_PLUS_PLUS_INIT, use_mini_batch=True) - (all_scores, _, clustering_scores, _, _, kmeans_init, + (all_scores, _, clustering_scores, _, kmeans_init, kmeans_training_op) = kmeans.training_graph() # Some heuristics to approximately whiten this output. all_scores = (all_scores[0] - 0.5) * 5 diff --git a/tensorflow/contrib/factorization/g3doc/kmeans.md b/tensorflow/contrib/factorization/g3doc/kmeans.md index b55c9d09ad386b84623d3648c5be83cbba8bbff9..c1843f0bf0704503d43c28d186dc826f0677711f 100644 --- a/tensorflow/contrib/factorization/g3doc/kmeans.md +++ b/tensorflow/contrib/factorization/g3doc/kmeans.md @@ -24,7 +24,11 @@ the full-batch version. approach for computing the initial cluster assignments that is expensive but is typically less prone to getting stuck in bad local minima. -We provide distributed implementations of both full-batch and mini-batch -K-Means algorithm. Both K-Means++ and random initialization are supported. -The user can also choose between **Cosine** and **Squared Euclidean** distance -metrics. +**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)** +provides a very fast seeding method that provides high quality centers +comparable to K-Means++ seeding. k-MC2 works particularly well if it is combined +with Mini-batch K-Means. + +We provide distributed implementations of both full-batch and mini-batch K-Means +algorithm. K-Means++, k-MC2 and random initialization are supported. The user +can also choose between **Cosine** and **Squared Euclidean** distance metrics. diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index a2136c08bbc2e91f4587b1cdacbfe3b1d1073949..dd61f59585aee2e0245cfd6797b313b972c19bc5 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU), KmeansPlusPlusInitializationOp); +// Implementation of one single Markov Chain for the k-MC^2 algorithm +class KMC2ChainInitializationOp : public OpKernel { + public: + explicit KMC2ChainInitializationOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64})); + } + + void Compute(OpKernelContext* context) override { + const Tensor& distances_tensor = context->input(0); + const Tensor& seed_tensor = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()), + InvalidArgument("Input distances should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()), + InvalidArgument("Input seed should be a scalar.")); + const int64 num_points = distances_tensor.dim_size(0); + const int64 seed = seed_tensor.scalar()(); + OP_REQUIRES(context, num_points > 0, + InvalidArgument("Expected distances_tensor.size() > 0.")); + + random::PhiloxRandom random(seed); + random::SimplePhilox rng(&random); + + auto distances = distances_tensor.flat(); + // Set the initial state of the Markov chain to be the first candidate. + int64 selected_index = 0; + float selected_distance = distances(selected_index); + // Build a Markov chain of length num_points. + for (int64 i = 1; i < num_points; ++i) { + const float candidate_distance = distances(i); + // Set the next state of the Markov chain to be the candidate with + // probability min(1, candidate_distance/selected_distance). + if (candidate_distance > rng.RandFloat() * selected_distance) { + selected_index = i; + selected_distance = candidate_distance; + } + } + + Tensor* output_sampled_index_tensor; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), + &output_sampled_index_tensor)); + auto output = output_sampled_index_tensor->scalar(); + // Return the last state of the Markov chain as the new center. + output() = selected_index; + } +}; + +REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU), + KMC2ChainInitializationOp); + // Operator for computing the nearest neighbors for a set of points. class NearestNeighborsOp : public OpKernel { public: diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc index c4a96b048db878169acc69b4d8caed5d4e04c18f..8172a7cebb81de70c530dbdd9ce0ca3eda4bc2ce 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc @@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample); #undef RUN_BM_KmeansPlusPlusInitialization #undef BENCHMARK_KMEANS_PLUS_PLUS +Graph* SetUpKMC2Initialization(int num_points) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor distances(DT_FLOAT, TensorShape({num_points})); + Tensor seed(DT_INT64, TensorShape({})); + distances.flat().setRandom(); + seed.flat().setConstant(12345); + + TF_CHECK_OK( + NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization") + .Input(test::graph::Constant(g, distances)) + .Input(test::graph::Constant(g, seed)) + .Finalize(g, nullptr /* node */)); + return g; +} + +template +void BM_KMC2Initialization(int iters) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast(iters) * num_points * num_dims * + num_to_sample); + testing::UseRealTime(); + Graph* g = SetUpKMC2Initialization(num_points); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} +#define BENCHMARK_KMC2(p, c, d) \ + void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \ + BM_KMC2Initialization(iters); \ + } \ + BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d); + +#define RUN_BM_KMC2Initialization \ + BENCHMARK_KMC2(k10Points, k2Centers, k100Dim); \ + BENCHMARK_KMC2(k10Points, k5Centers, k100Dim); \ + BENCHMARK_KMC2(k10Points, k10Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim) + +RUN_BM_KMC2Initialization; +#undef RUN_BM_KMC2Initialization +#undef BENCHMARK_KMC2 + Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers, int k) { Graph* g = new Graph(OpRegistry::Global()); diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc index f2dfcf7ed0fb05264b10dee9980a246a5f2e49fa..2686702c1d5768f661dac610c96089eb02e360d7 100644 --- a/tensorflow/contrib/factorization/ops/clustering_ops.cc +++ b/tensorflow/contrib/factorization/ops/clustering_ops.cc @@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter samples: Matrix of shape (num_to_sample, d). The sampled rows. )"); +REGISTER_OP("KMC2ChainInitialization") + .Input("distances: float32") + .Input("seed: int64") + .Output("index: int64") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"( +Returns the index of a data point that should be added to the seed set. + +Entries in distances are assumed to be squared distances of candidate points to +the already sampled centers in the seed set. The op constructs one Markov chain +of the k-MC^2 algorithm and returns the index of one candidate point to be added +as an additional cluster center. + +distances: Vector with squared distances to the closest previously sampled + cluster center for each candidate point. +seed: Scalar. Seed for initializing the random number generator. +index: Scalar with the index of the sampled point. +)"); + REGISTER_OP("NearestNeighbors") .Input("points: float32") .Input("centers: float32") diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py index 450f64063a2a357e422cd14761864d511c0e6cce..1322f7ce5f83d82c76040a30699137cd2bf491b5 100644 --- a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py +++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py @@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase): self.runTestWithSeed(seed) +class KMC2InitializationTest(test.TestCase): + + def runTestWithSeed(self, seed): + with self.test_session(): + distances = np.zeros(1000).astype(np.float32) + distances[6] = 10e7 + distances[4] = 10e3 + + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 6) + distances[6] = 0.0 + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 4) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + +class KMC2InitializationLargeTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(1001) + self._distances[500] = 100.0 + self._distances[1000] = 50.0 + + def testBasic(self): + with self.test_session(): + counts = {} + seed = 0 + for i in range(50): + sample = clustering_ops.kmc2_chain_initialization( + self._distances, seed + i).eval() + counts[sample] = counts.get(sample, 0) + 1 + self.assertEquals(len(counts), 2) + self.assertTrue(500 in counts) + self.assertTrue(1000 in counts) + self.assertGreaterEqual(counts[500], 5) + self.assertGreaterEqual(counts[1000], 5) + + +class KMC2InitializationCornercaseTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(10) + + def runTestWithSeed(self, seed): + with self.test_session(): + sampled_point = clustering_ops.kmc2_chain_initialization( + self._distances, seed) + self.assertEquals(sampled_point.eval(), 0) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + # A simple test that can be verified by hand. class NearestCentersTest(test.TestCase): diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index e5c918066217371b076aa23c2e28650608f93fb0..96cc80ce241347ebca5b68140f1b1c8b9898ae72 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -50,6 +50,10 @@ COSINE_DISTANCE = 'cosine' RANDOM_INIT = 'random' KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus' +KMC2_INIT = 'kmc2' + +# The name of the variable holding the cluster centers. Used by the Estimator. +CLUSTERS_VAR_NAME = 'clusters' class KMeans(object): @@ -63,7 +67,8 @@ class KMeans(object): use_mini_batch=False, mini_batch_steps_per_iteration=1, random_seed=0, - kmeans_plus_plus_num_retries=2): + kmeans_plus_plus_num_retries=2, + kmc2_chain_length=200): """Creates an object for generating KMeans clustering graph. This class implements the following variants of K-means algorithm: @@ -92,7 +97,8 @@ class KMeans(object): exactly like a full-batch version. Args: - inputs: An input tensor or list of input tensors + inputs: An input tensor or list of input tensors. It is assumed that the + data points have been previously randomly permuted. num_clusters: An integer tensor specifying the number of clusters. This argument is ignored if initial_clusters is a tensor or numpy array. initial_clusters: Specifies the clusters used during initialization. One @@ -101,6 +107,7 @@ class KMeans(object): - a function f(inputs, k) that returns up to k centers from `inputs`. - "random": Choose centers randomly from `inputs`. - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`. + - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`. In the last three cases, one batch of `inputs` may not yield `num_clusters` centers, in which case initialization will require multiple batches until enough centers are chosen. In the case of @@ -118,13 +125,17 @@ class KMeans(object): additional points to draw from the current distribution before selecting the best. If a negative value is specified, a heuristic is used to sample O(log(num_to_sample)) additional points. + kmc2_chain_length: Determines how many candidate points are used by the + k-MC2 algorithm to produce one new cluster centers. If a (mini-)batch + contains less points, one new cluster center is generated from the + (mini-)batch. Raises: ValueError: An invalid argument was passed to initial_clusters or distance_metric. """ if isinstance(initial_clusters, str) and initial_clusters not in [ - RANDOM_INIT, KMEANS_PLUS_PLUS_INIT + RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT ]: raise ValueError( "Unsupported initialization algorithm '%s'" % initial_clusters) @@ -138,6 +149,7 @@ class KMeans(object): self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration) self._random_seed = random_seed self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._kmc2_chain_length = kmc2_chain_length @classmethod def _distance_graph(cls, inputs, clusters, distance_metric): @@ -279,7 +291,7 @@ class KMeans(object): """ init_value = array_ops.constant([], dtype=dtypes.float32) cluster_centers = variable_scope.variable( - init_value, name='clusters', validate_shape=False) + init_value, name=CLUSTERS_VAR_NAME, validate_shape=False) cluster_centers_initialized = variable_scope.variable( False, dtype=dtypes.bool, name='initialized') @@ -299,9 +311,10 @@ class KMeans(object): else: cluster_centers_updated = cluster_centers update_in_steps = None - cluster_counts = (variable_scope.variable( - array_ops.ones([num_clusters], dtype=dtypes.int64)) - if self._use_mini_batch else None) + cluster_counts = ( + variable_scope.variable( + array_ops.ones([num_clusters], dtype=dtypes.int64)) + if self._use_mini_batch else None) return (cluster_centers, cluster_centers_initialized, cluster_counts, cluster_centers_updated, update_in_steps) @@ -337,7 +350,6 @@ class KMeans(object): assigned cluster instead. cluster_centers_initialized: scalar indicating whether clusters have been initialized. - cluster_centers_var: a Variable holding the cluster centers. init_op: an op to initialize the clusters. training_op: an op that runs an iteration of training. """ @@ -357,7 +369,7 @@ class KMeans(object): init_op = _InitializeClustersOpFactory( self._inputs, num_clusters, initial_clusters, self._distance_metric, self._random_seed, self._kmeans_plus_plus_num_retries, - cluster_centers_var, cluster_centers_updated, + self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated, cluster_centers_initialized).op() cluster_centers = cluster_centers_var @@ -381,7 +393,7 @@ class KMeans(object): inputs, num_clusters, cluster_idx, cluster_centers_var) return (all_scores, cluster_idx, scores, cluster_centers_initialized, - cluster_centers_var, init_op, training_op) + init_op, training_op) def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var, cluster_centers_updated, total_counts): @@ -518,8 +530,9 @@ class KMeans(object): array_ops.reshape(array_ops.shape(inp)[0], [-1])), [-1, 1]), cluster_idx, num_clusters)) with ops.colocate_with(cluster_centers, ignore_existing=True): - new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast( - math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon) + new_clusters_centers = math_ops.add_n(cluster_sums) / ( + math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + + epsilon) if self._clusters_l2_normalized(): new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1) return state_ops.assign(cluster_centers, new_clusters_centers) @@ -546,9 +559,12 @@ class _InitializeClustersOpFactory(object): cluster_centers_initialized := true """ + # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case. + def __init__(self, inputs, num_clusters, initial_clusters, distance_metric, - random_seed, kmeans_plus_plus_num_retries, cluster_centers, - cluster_centers_updated, cluster_centers_initialized): + random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length, + cluster_centers, cluster_centers_updated, + cluster_centers_initialized): """Creates an op factory. Args: @@ -558,6 +574,7 @@ class _InitializeClustersOpFactory(object): distance_metric: See KMeans constructor. random_seed: See KMeans constructor. kmeans_plus_plus_num_retries: See KMeans constructor. + kmc2_chain_length: See KMeans constructor. cluster_centers: The TF variable holding the initial centers. It may already contain some centers when the op is executed. cluster_centers_updated: A second TF variable to hold a copy of the @@ -573,6 +590,7 @@ class _InitializeClustersOpFactory(object): self._distance_metric = distance_metric self._random_seed = random_seed self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._kmc2_chain_length = kmc2_chain_length self._cluster_centers = cluster_centers self._cluster_centers_updated = cluster_centers_updated self._cluster_centers_initialized = cluster_centers_initialized @@ -602,6 +620,90 @@ class _InitializeClustersOpFactory(object): math_ops.to_int64(self._num_remaining), self._random_seed, self._kmeans_plus_plus_num_retries) + def _kmc2_multiple_centers(self): + """Adds new initial cluster centers using the k-MC2 algorithm. + + In each call to the op, the provided batch is split into subsets based on + the specified `kmc2_chain_length`. On each subset, a single Markov chain of + the k-MC2 algorithm is used to add *one* new center cluster center. If there + are less than `kmc2_chain_length` points in the subset, a single center is + added using one Markov chain on the full input. It is assumed that the + provided batch has previously been randomly permuted. Otherwise, k-MC2 may + return suboptimal centers. + + Returns: + An op that adds new cluster centers. + """ + # The op only operates on the first shard of data. + first_shard = self._inputs[0] + # Number of points in the input that can be used. + batch_size = array_ops.shape(first_shard)[0] + # Maximum number of subsets such that the size of each subset is at least + # `kmc2_chain_length`. Final subsets may be larger. + max_to_sample = math_ops.cast( + batch_size / self._kmc2_chain_length, dtype=dtypes.int32) + # We sample at least one new center and at most all remaining centers. + num_to_sample = math_ops.maximum( + math_ops.minimum(self._num_remaining, max_to_sample), 1) + + def _cond(i, _): + """Stopping condition for the while loop.""" + return math_ops.less(i, num_to_sample) + + def _body(i, _): + """Body that adds a single new center based on a subset.""" + + def _sample_random(): + """Returns a random point as a cluster center.""" + # By assumption the batch is reshuffled and _sample_random is always + # called for i=0. Hence, we simply return the first point. + new_center = array_ops.reshape(first_shard[0], [1, -1]) + if self._distance_metric == COSINE_DISTANCE: + new_center = nn_impl.l2_normalize(new_center, dim=1) + return new_center + + def _sample_kmc2_chain(): + """Returns previous centers as well as a new center sampled using k-MC2. + """ + # Extract the subset from the underlying batch. + start = i * self._kmc2_chain_length + end = start + self._kmc2_chain_length + subset = first_shard[start:end] + # Compute the distances from points in the subset to previous centers. + _, distances = gen_clustering_ops.nearest_neighbors( + subset, self._cluster_centers, 1) + # Sample index of new center using k-MC2 Markov chain. + new_center_index = gen_clustering_ops.kmc2_chain_initialization( + array_ops.squeeze(distances), self._random_seed) + # Extract actual new center. + newly_sampled_center = array_ops.reshape(subset[new_center_index], + [1, -1]) + # Return concatenation with previously sampled centers. + if self._distance_metric == COSINE_DISTANCE: + newly_sampled_center = nn_impl.l2_normalize( + newly_sampled_center, dim=1) + return array_ops.concat([self._cluster_centers, newly_sampled_center], + 0) + + # Obtain a random point if there are no previously sampled centers. + # Otherwise, construct a k-MC2 Markov chain. + new_centers = control_flow_ops.cond( + math_ops.equal(self._num_selected, 0), _sample_random, + _sample_kmc2_chain) + # Assign new cluster centers to underlying variable. + assigned_centers = state_ops.assign( + self._cluster_centers, new_centers, validate_shape=False) + if self._cluster_centers_updated is not self._cluster_centers: + assigned_centers = state_ops.assign( + self._cluster_centers_updated, + assigned_centers, + validate_shape=False) + return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0] + + # Add num_to_sample new data points. + _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0]) + return num_remaining + def _greedy_batch_sampler(self, sampler): # If the input dataset size is smaller than the number of centers # remaining, choose the entire input dataset as centers. This can happen @@ -655,7 +757,10 @@ class _InitializeClustersOpFactory(object): with ops.control_dependencies([ check_ops.assert_positive(self._num_remaining), ]): - num_now_remaining = self._add_new_centers() + if self._initial_clusters == KMC2_INIT: + num_now_remaining = self._kmc2_multiple_centers() + else: + num_now_remaining = self._add_new_centers() return control_flow_ops.cond( math_ops.equal(num_now_remaining, 0), lambda: state_ops.assign(self._cluster_centers_initialized, True), diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5413fc3f2642443621b33d325e3d8c893fd6ac --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -0,0 +1,397 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A canned Estimator for k-means clustering.""" + +# TODO(ccolby): Move clustering_ops.py into this file and streamline the code. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from tensorflow.contrib.factorization.python.ops import clustering_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util + + +class _LossRelativeChangeHook(session_run_hook.SessionRunHook): + """Stops when the change in loss goes below a tolerance.""" + + def __init__(self, loss_tensor, tolerance): + """Creates a _LossRelativeChangeHook. + + Args: + loss_tensor: A scalar tensor of the loss value. + tolerance: A relative tolerance of loss change between iterations. + """ + self._loss_tensor = loss_tensor + self._tolerance = tolerance + self._prev_loss = None + + def before_run(self, run_context): + del run_context # unused + return session_run_hook.SessionRunArgs(self._loss_tensor) + + def after_run(self, run_context, run_values): + loss = run_values.results + assert loss is not None + if self._prev_loss: + relative_change = (abs(loss - self._prev_loss) / + (1 + abs(self._prev_loss))) + if relative_change < self._tolerance: + run_context.request_stop() + self._prev_loss = loss + + +class _InitializeClustersHook(session_run_hook.SessionRunHook): + """Initializes the cluster centers. + + The chief repeatedly invokes an initialization op until all cluster centers + are initialized. The workers wait for the initialization phase to complete. + """ + + def __init__(self, init_op, is_initialized_var, is_chief): + """Creates an _InitializeClustersHook. + + Args: + init_op: An op that, when run, will choose some initial cluster centers. + This op may need to be run multiple times to choose all the centers. + is_initialized_var: A boolean variable reporting whether all initial + centers have been chosen. + is_chief: A boolean specifying whether this task is the chief. + """ + self._init_op = init_op + self._is_initialized_var = is_initialized_var + self._is_chief = is_chief + + def after_create_session(self, session, coord): + del coord # unused + assert self._init_op.graph is ops.get_default_graph() + assert self._is_initialized_var.graph is self._init_op.graph + while True: + try: + if session.run(self._is_initialized_var): + break + elif self._is_chief: + session.run(self._init_op) + else: + time.sleep(1) + except RuntimeError as e: + logging.info(e) + + +def _parse_tensor_or_dict(features): + """Helper function to convert the input points into a usable format. + + Args: + features: The input points. + + Returns: + If `features` is a dict of `k` features, each of which is a vector of `n` + scalars, the return value is a Tensor of shape `(n, k)` representing `n` + input points, where the items in the `k` dimension are sorted + lexicographically by `features` key. If `features` is not a dict, it is + returned unmodified. + """ + if isinstance(features, dict): + keys = sorted(features.keys()) + with ops.colocate_with(features[keys[0]]): + features = array_ops.concat([features[k] for k in keys], axis=1) + return features + + +class _ModelFn(object): + """Model function for the estimator.""" + + def __init__(self, num_clusters, initial_clusters, distance_metric, + random_seed, use_mini_batch, mini_batch_steps_per_iteration, + kmeans_plus_plus_num_retries, relative_tolerance): + self._num_clusters = num_clusters + self._initial_clusters = initial_clusters + self._distance_metric = distance_metric + self._random_seed = random_seed + self._use_mini_batch = use_mini_batch + self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration + self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._relative_tolerance = relative_tolerance + + def model_fn(self, features, mode, config): + """Model function for the estimator. + + Note that this does not take a `1abels` arg. This works, but `input_fn` must + return either `features` or, equivalently, `(features, None)`. + + Args: + features: The input points. See @{tf.estimator.Estimator}. + mode: See @{tf.estimator.Estimator}. + config: See @{tf.estimator.Estimator}. + + Returns: + A @{tf.estimator.EstimatorSpec} (see @{tf.estimator.Estimator}) specifying + this behavior: + * `train_op`: Execute one mini-batch or full-batch run of Lloyd's + algorithm. + * `loss`: The sum of the squared distances from each input point to its + closest center. + * `eval_metric_ops`: Maps `SCORE` to `loss`. + * `predictions`: Maps `ALL_DISTANCES` to the distance from each input + point to each cluster center; maps `CLUSTER_INDEX` to the index of + the closest cluster center for each input point. + """ + # input_points is a single Tensor. Therefore, the sharding functionality + # in clustering_ops is unused, and some of the values below are lists of a + # single item. + input_points = _parse_tensor_or_dict(features) + + # Let N = the number of input_points. + # all_distances: A list of one matrix of shape (N, num_clusters). Each value + # is the distance from an input point to a cluster center. + # model_predictions: A list of one vector of shape (N). Each value is the + # cluster id of an input point. + # losses: Similar to cluster_idx but provides the distance to the cluster + # center. + # is_initialized: scalar indicating whether the initial cluster centers + # have been chosen; see init_op. + # cluster_centers_var: a Variable containing the cluster centers. + # init_op: an op to choose the initial cluster centers. A single worker + # repeatedly executes init_op until is_initialized becomes True. + # training_op: an op that runs an iteration of training, either an entire + # Lloyd iteration or a mini-batch of a Lloyd iteration. Multiple workers + # may execute this op, but only after is_initialized becomes True. + (all_distances, model_predictions, losses, is_initialized, init_op, + training_op) = clustering_ops.KMeans( + inputs=input_points, + num_clusters=self._num_clusters, + initial_clusters=self._initial_clusters, + distance_metric=self._distance_metric, + use_mini_batch=self._use_mini_batch, + mini_batch_steps_per_iteration=self._mini_batch_steps_per_iteration, + random_seed=self._random_seed, + kmeans_plus_plus_num_retries=self._kmeans_plus_plus_num_retries + ).training_graph() + + loss = math_ops.reduce_sum(losses) + summary.scalar('loss/raw', loss) + + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) + training_op = control_flow_ops.with_dependencies([training_op, incr_step], + loss) + + training_hooks = [ + _InitializeClustersHook(init_op, is_initialized, config.is_chief) + ] + if self._relative_tolerance is not None: + training_hooks.append( + _LossRelativeChangeHook(loss, self._relative_tolerance)) + + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions={ + KMeansClustering.ALL_DISTANCES: all_distances[0], + KMeansClustering.CLUSTER_INDEX: model_predictions[0], + }, + loss=loss, + train_op=training_op, + eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)}, + training_hooks=training_hooks) + + +# TODO(agarwal,ands): support sharded input. +class KMeansClustering(estimator.Estimator): + """An Estimator for K-Means clustering.""" + + # Valid values for the distance_metric constructor argument. + SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE + COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE + + # Values for initial_clusters constructor argument. + RANDOM_INIT = clustering_ops.RANDOM_INIT + KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT + + # Metric returned by evaluate(): The sum of the squared distances from each + # input point to its closest center. + SCORE = 'score' + + # Keys returned by predict(). + # ALL_DISTANCES: The distance from each input point to each cluster center. + # CLUSTER_INDEX: The index of the closest cluster center for each input point. + CLUSTER_INDEX = 'cluster_index' + ALL_DISTANCES = 'all_distances' + + def __init__(self, + num_clusters, + model_dir=None, + initial_clusters=RANDOM_INIT, + distance_metric=SQUARED_EUCLIDEAN_DISTANCE, + random_seed=0, + use_mini_batch=True, + mini_batch_steps_per_iteration=1, + kmeans_plus_plus_num_retries=2, + relative_tolerance=None, + config=None): + """Creates an Estimator for running KMeans training and inference. + + This Estimator implements the following variants of the K-means algorithm: + + If `use_mini_batch` is False, it runs standard full batch K-means. Each + training step runs a single iteration of K-Means and must process the full + input at once. To run in this mode, the `input_fn` passed to `train` must + return the entire input dataset. + + If `use_mini_batch` is True, it runs a generalization of the mini-batch + K-means algorithm. It runs multiple iterations, where each iteration is + composed of `mini_batch_steps_per_iteration` steps. Each training step + accumulates the contribution from one mini-batch into temporary storage. + Every `mini_batch_steps_per_iteration` steps, the cluster centers are + updated and the temporary storage cleared for the next iteration. Note + that: + * If `mini_batch_steps_per_iteration=1`, the algorithm reduces to the + standard K-means mini-batch algorithm. + * If `mini_batch_steps_per_iteration = num_inputs / batch_size`, the + algorithm becomes an asynchronous version of the full-batch algorithm. + However, there is no guarantee by this implementation that each input + is seen exactly once per iteration. Also, different updates are applied + asynchronously without locking. So this asynchronous version may not + behave exactly like a full-batch version. + + Args: + num_clusters: An integer tensor specifying the number of clusters. This + argument is ignored if `initial_clusters` is a tensor or numpy array. + model_dir: The directory to save the model results and log files. + initial_clusters: Specifies how the initial cluster centers are chosen. + One of the following: + * a tensor or numpy array with the initial cluster centers. + * a callable `f(inputs, k)` that selects and returns up to `k` centers + from an input batch. `f` is free to return any number of centers + from `0` to `k`. It will be invoked on successive input batches + as necessary until all `num_clusters` centers are chosen. + * `KMeansClustering.RANDOM_INIT`: Choose centers randomly from an input + batch. If the batch size is less than `num_clusters` then the + entire batch is chosen to be initial cluster centers and the + remaining centers are chosen from successive input batches. + * `KMeansClustering.KMEANS_PLUS_PLUS_INIT`: Use kmeans++ to choose + centers from the first input batch. If the batch size is less + than `num_clusters`, a TensorFlow runtime error occurs. + distance_metric: The distance metric used for clustering. One of: + * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance + between vectors `u` and `v` is defined as `||u - v||_2` which is + the square root of the sum of the absolute squares of the elements' + difference. + * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors + `u` and `v` is defined as `1 - (u . v) / (||u||_2 ||v||_2)`. + random_seed: Python integer. Seed for PRNG used to initialize centers. + use_mini_batch: A boolean specifying whether to use the mini-batch k-means + algorithm. See explanation above. + mini_batch_steps_per_iteration: The number of steps after which the + updated cluster centers are synced back to a master copy. Used only if + `use_mini_batch=True`. See explanation above. + kmeans_plus_plus_num_retries: For each point that is sampled during + kmeans++ initialization, this parameter specifies the number of + additional points to draw from the current distribution before selecting + the best. If a negative value is specified, a heuristic is used to + sample `O(log(num_to_sample))` additional points. Used only if + `initial_clusters=KMeansClustering.KMEANS_PLUS_PLUS_INIT`. + relative_tolerance: A relative tolerance of change in the loss between + iterations. Stops learning if the loss changes less than this amount. + This may not work correctly if `use_mini_batch=True`. + config: See @{tf.estimator.Estimator}. + + Raises: + ValueError: An invalid argument was passed to `initial_clusters` or + `distance_metric`. + """ + if isinstance(initial_clusters, str) and initial_clusters not in [ + KMeansClustering.RANDOM_INIT, KMeansClustering.KMEANS_PLUS_PLUS_INIT + ]: + raise ValueError( + "Unsupported initialization algorithm '%s'" % initial_clusters) + if distance_metric not in [ + KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + KMeansClustering.COSINE_DISTANCE + ]: + raise ValueError("Unsupported distance metric '%s'" % distance_metric) + super(KMeansClustering, self).__init__( + model_fn=_ModelFn( + num_clusters, initial_clusters, distance_metric, random_seed, + use_mini_batch, mini_batch_steps_per_iteration, + kmeans_plus_plus_num_retries, relative_tolerance).model_fn, + model_dir=model_dir, + config=config) + + def _predict_one_key(self, input_fn, predict_key): + for result in self.predict(input_fn=input_fn, predict_keys=[predict_key]): + yield result[predict_key] + + def predict_cluster_index(self, input_fn): + """Finds the index of the closest cluster center to each input point. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.predict}. + + Yields: + The index of the closest cluster center for each input point. + """ + for index in self._predict_one_key(input_fn, + KMeansClustering.CLUSTER_INDEX): + yield index + + def score(self, input_fn): + """Returns the sum of squared distances to nearest clusters. + + Note that this function is different from the corresponding one in sklearn + which returns the negative sum. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.evaluate}. Only one + batch is retrieved. + + Returns: + The sum of the squared distance from each point in the first batch of + inputs to its nearest cluster center. + """ + return self.evaluate(input_fn=input_fn, steps=1)[KMeansClustering.SCORE] + + def transform(self, input_fn): + """Transforms each input point to its distances to all cluster centers. + + Note that if `distance_metric=KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`, + this + function returns the squared Euclidean distance while the corresponding + sklearn function returns the Euclidean distance. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.predict}. + + Yields: + The distances from each input point to each cluster center. + """ + for distances in self._predict_one_key(input_fn, + KMeansClustering.ALL_DISTANCES): + yield distances + + def cluster_centers(self): + """Returns the cluster centers.""" + return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4709d7942583f1406a3fa0ff3a078d0283872ea6 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -0,0 +1,575 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for KMeans.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import time + +import numpy as np +from sklearn.cluster import KMeans as SklearnKMeans + +# pylint: disable=g-import-not-at-top +from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import flags +from tensorflow.python.platform import test +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import queue_runner + +FLAGS = flags.FLAGS + + +def normalize(x): + return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True)) + + +def cosine_similarity(x, y): + return np.dot(normalize(x), np.transpose(normalize(y))) + + +def make_random_centers(num_centers, num_dims, center_norm=500): + return np.round( + np.random.rand(num_centers, num_dims).astype(np.float32) * center_norm) + + +def make_random_points(centers, num_points, max_offset=20): + num_centers, num_dims = centers.shape + assignments = np.random.choice(num_centers, num_points) + offsets = np.round( + np.random.randn(num_points, num_dims).astype(np.float32) * max_offset) + return (centers[assignments] + offsets, assignments, np.add.reduce( + offsets * offsets, 1)) + + +class KMeansTestBase(test.TestCase): + + def input_fn(self, + batch_size=None, + points=None, + randomize=None, + num_epochs=None): + """Returns an input_fn that randomly selects batches from given points.""" + batch_size = batch_size or self.batch_size + points = points if points is not None else self.points + num_points = points.shape[0] + if randomize is None: + randomize = (self.use_mini_batch and + self.mini_batch_steps_per_iteration <= 1) + + def _fn(): + x = constant_op.constant(points) + if batch_size == num_points: + return input_lib.limit_epochs(x, num_epochs=num_epochs), None + if randomize: + indices = random_ops.random_uniform( + constant_op.constant([batch_size]), + minval=0, + maxval=num_points - 1, + dtype=dtypes.int32, + seed=10) + else: + # We need to cycle through the indices sequentially. We create a queue + # to maintain the list of indices. + q = data_flow_ops.FIFOQueue(num_points, dtypes.int32, ()) + + # Conditionally initialize the Queue. + def _init_q(): + with ops.control_dependencies( + [q.enqueue_many(math_ops.range(num_points))]): + return control_flow_ops.no_op() + + init_q = control_flow_ops.cond(q.size() <= 0, _init_q, + control_flow_ops.no_op) + with ops.control_dependencies([init_q]): + offsets = q.dequeue_many(batch_size) + with ops.control_dependencies([q.enqueue_many(offsets)]): + indices = array_ops.identity(offsets) + batch = array_ops.gather(x, indices) + return (input_lib.limit_epochs(batch, num_epochs=num_epochs), None) + + return _fn + + @staticmethod + def config(tf_random_seed): + return run_config.RunConfig().replace(tf_random_seed=tf_random_seed) + + @property + def initial_clusters(self): + return kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT + + @property + def batch_size(self): + return self.num_points + + @property + def use_mini_batch(self): + return False + + @property + def mini_batch_steps_per_iteration(self): + return 1 + + +class KMeansTest(KMeansTestBase): + + def setUp(self): + np.random.seed(3) + self.num_centers = 5 + self.num_dims = 2 + self.num_points = 1000 + self.true_centers = make_random_centers(self.num_centers, self.num_dims) + self.points, _, self.scores = make_random_points(self.true_centers, + self.num_points) + self.true_score = np.add.reduce(self.scores) + + def _kmeans(self, relative_tolerance=None): + return kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + random_seed=24, + relative_tolerance=relative_tolerance) + + def test_clusters(self): + kmeans = self._kmeans() + kmeans.train(input_fn=self.input_fn(), steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims]) + + def test_fit(self): + kmeans = self._kmeans() + kmeans.train(input_fn=self.input_fn(), steps=1) + score1 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + steps = 10 * self.num_points // self.batch_size + kmeans.train(input_fn=self.input_fn(), steps=steps) + score2 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + self.assertTrue(score1 > score2) + self.assertNear(self.true_score, score2, self.true_score * 0.05) + + def test_monitor(self): + if self.use_mini_batch: + # We don't test for use_mini_batch case since the loss value can be noisy. + return + kmeans = kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(14), + random_seed=12, + relative_tolerance=1e-4) + + kmeans.train( + input_fn=self.input_fn(), + # Force it to train until the relative tolerance monitor stops it. + steps=None) + score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + self.assertNear(self.true_score, score, self.true_score * 0.01) + + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.cluster_centers() + + # Make a small test set + num_points = 10 + points, true_assignments, true_offsets = make_random_points( + clusters, num_points) + input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1) + # Test predict + assignments = list(kmeans.predict_cluster_index(input_fn)) + self.assertAllEqual(assignments, true_assignments) + + # Test score + score = kmeans.score(input_fn=lambda: (constant_op.constant(points), None)) + self.assertNear(score, np.sum(true_offsets), 0.01 * score) + + # Test transform + transform = list(kmeans.transform(input_fn)) + true_transform = np.maximum( + 0, + np.sum(np.square(points), axis=1, keepdims=True) - + 2 * np.dot(points, np.transpose(clusters)) + np.transpose( + np.sum(np.square(clusters), axis=1, keepdims=True))) + self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + + +class KMeansTestMultiStageInit(KMeansTestBase): + + def test_random(self): + points = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + kmeans.train( + input_fn=self.input_fn(batch_size=1, points=points, randomize=False), + steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(points, clusters) + + def test_kmeans_plus_plus_batch_just_right(self): + points = np.array([[1, 2]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + kmeans.train( + input_fn=self.input_fn(batch_size=1, points=points, randomize=False), + steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(points, clusters) + + def test_kmeans_plus_plus_batch_too_small(self): + points = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + with self.assertRaisesOpError(AssertionError): + kmeans.train( + input_fn=self.input_fn(batch_size=4, points=points, randomize=False), + steps=1) + + +class MiniBatchKMeansTest(KMeansTest): + + @property + def batch_size(self): + return 50 + + @property + def use_mini_batch(self): + return True + + +class FullBatchAsyncKMeansTest(KMeansTest): + + @property + def batch_size(self): + return 50 + + @property + def use_mini_batch(self): + return True + + @property + def mini_batch_steps_per_iteration(self): + return self.num_points // self.batch_size + + +class KMeansCosineDistanceTest(KMeansTestBase): + + def setUp(self): + self.points = np.array( + [[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2], [0.1, 2.5], [0.2, 2], + [0.1, 3], [0.2, 4]], + dtype=np.float32) + self.num_points = self.points.shape[0] + self.true_centers = np.array( + [ + normalize( + np.mean(normalize(self.points)[0:4, :], axis=0, + keepdims=True))[0], + normalize( + np.mean(normalize(self.points)[4:, :], axis=0, + keepdims=True))[0] + ], + dtype=np.float32) + self.true_assignments = np.array([0] * 4 + [1] * 4) + self.true_score = len(self.points) - np.tensordot( + normalize(self.points), self.true_centers[self.true_assignments]) + + self.num_centers = 2 + self.kmeans = kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT, + distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(3)) + + def test_fit(self): + max_steps = 10 * self.num_points // self.batch_size + self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + centers = normalize(self.kmeans.cluster_centers()) + centers = centers[centers[:, 0].argsort()] + true_centers = self.true_centers[self.true_centers[:, 0].argsort()] + self.assertAllClose(centers, true_centers, atol=0.04) + + def test_transform(self): + self.kmeans.train(input_fn=self.input_fn(), steps=10) + centers = normalize(self.kmeans.cluster_centers()) + true_transform = 1 - cosine_similarity(self.points, centers) + transform = list( + self.kmeans.transform( + input_fn=self.input_fn(batch_size=self.num_points, num_epochs=1))) + self.assertAllClose(transform, true_transform, atol=1e-3) + + def test_predict(self): + max_steps = 10 * self.num_points // self.batch_size + self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + centers = normalize(self.kmeans.cluster_centers()) + + assignments = list( + self.kmeans.predict_cluster_index( + input_fn=self.input_fn(num_epochs=1, batch_size=self.num_points))) + self.assertAllClose( + centers[assignments], + self.true_centers[self.true_assignments], + atol=1e-2) + + centers = centers[centers[:, 0].argsort()] + true_centers = self.true_centers[self.true_centers[:, 0].argsort()] + self.assertAllClose(centers, true_centers, atol=0.04) + score = self.kmeans.score( + input_fn=self.input_fn(batch_size=self.num_points)) + self.assertAllClose(score, self.true_score, atol=1e-2) + + def test_predict_kmeans_plus_plus(self): + # Most points are concetrated near one center. KMeans++ is likely to find + # the less populated centers. + points = np.array( + [[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], [-3.1, -3.2], + [-2.8, -3.], [-2.9, -3.1], [-3., -3.1], [-3., -3.1], [-3.2, -3.], + [-3., -3.]], + dtype=np.float32) + true_centers = np.array( + [ + normalize( + np.mean(normalize(points)[0:2, :], axis=0, keepdims=True))[0], + normalize( + np.mean(normalize(points)[2:4, :], axis=0, keepdims=True))[0], + normalize(np.mean(normalize(points)[4:, :], axis=0, + keepdims=True))[0] + ], + dtype=np.float32) + true_assignments = [0] * 2 + [1] * 2 + [2] * 8 + true_score = len(points) - np.tensordot( + normalize(points), true_centers[true_assignments]) + + kmeans = kmeans_lib.KMeansClustering( + 3, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(3)) + kmeans.train( + input_fn=lambda: (constant_op.constant(points), None), steps=30) + + centers = normalize(kmeans.cluster_centers()) + self.assertAllClose( + sorted(centers.tolist()), sorted(true_centers.tolist()), atol=1e-2) + + def _input_fn(): + return (input_lib.limit_epochs( + constant_op.constant(points), num_epochs=1), None) + + assignments = list(kmeans.predict_cluster_index(input_fn=_input_fn)) + self.assertAllClose( + centers[assignments], true_centers[true_assignments], atol=1e-2) + + score = kmeans.score(input_fn=lambda: (constant_op.constant(points), None)) + self.assertAllClose(score, true_score, atol=1e-2) + + +class MiniBatchKMeansCosineTest(KMeansCosineDistanceTest): + + @property + def batch_size(self): + return 2 + + @property + def use_mini_batch(self): + return True + + +class FullBatchAsyncKMeansCosineTest(KMeansCosineDistanceTest): + + @property + def batch_size(self): + return 2 + + @property + def use_mini_batch(self): + return True + + @property + def mini_batch_steps_per_iteration(self): + return self.num_points // self.batch_size + + +class KMeansBenchmark(benchmark.Benchmark): + """Base class for benchmarks.""" + + def SetUp(self, + dimension=50, + num_clusters=50, + points_per_cluster=10000, + center_norm=500, + cluster_width=20): + np.random.seed(123456) + self.num_clusters = num_clusters + self.num_points = num_clusters * points_per_cluster + self.centers = make_random_centers( + self.num_clusters, dimension, center_norm=center_norm) + self.points, _, scores = make_random_points( + self.centers, self.num_points, max_offset=cluster_width) + self.score = float(np.sum(scores)) + + def _report(self, num_iters, start, end, scores): + print(scores) + self.report_benchmark( + iters=num_iters, + wall_time=(end - start) / num_iters, + extras={'true_sum_squared_distances': self.score, + 'fit_scores': scores}) + + def _fit(self, num_iters=10): + pass + + def benchmark_01_2dim_5center_500point(self): + self.SetUp(dimension=2, num_clusters=5, points_per_cluster=100) + self._fit() + + def benchmark_02_20dim_20center_10kpoint(self): + self.SetUp(dimension=20, num_clusters=20, points_per_cluster=500) + self._fit() + + def benchmark_03_100dim_50center_50kpoint(self): + self.SetUp(dimension=100, num_clusters=50, points_per_cluster=1000) + self._fit() + + def benchmark_03_100dim_50center_50kpoint_unseparated(self): + self.SetUp( + dimension=100, + num_clusters=50, + points_per_cluster=1000, + cluster_width=250) + self._fit() + + def benchmark_04_100dim_500center_500kpoint(self): + self.SetUp(dimension=100, num_clusters=500, points_per_cluster=1000) + self._fit(num_iters=4) + + def benchmark_05_100dim_500center_500kpoint_unseparated(self): + self.SetUp( + dimension=100, + num_clusters=500, + points_per_cluster=1000, + cluster_width=250) + self._fit(num_iters=4) + + +class TensorflowKMeansBenchmark(KMeansBenchmark): + + def _fit(self, num_iters=10): + scores = [] + start = time.time() + for i in range(num_iters): + print('Starting tensorflow KMeans: %d' % i) + tf_kmeans = kmeans_lib.KMeansClustering( + self.num_clusters, + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + kmeans_plus_plus_num_retries=int(math.log(self.num_clusters) + 2), + random_seed=i * 42, + relative_tolerance=1e-6, + config=self.config(3)) + tf_kmeans.train( + input_fn=lambda: (constant_op.constant(self.points), None), steps=50) + _ = tf_kmeans.cluster_centers() + scores.append( + tf_kmeans.score( + input_fn=lambda: (constant_op.constant(self.points), None))) + self._report(num_iters, start, time.time(), scores) + + +class SklearnKMeansBenchmark(KMeansBenchmark): + + def _fit(self, num_iters=10): + scores = [] + start = time.time() + for i in range(num_iters): + print('Starting sklearn KMeans: %d' % i) + sklearn_kmeans = SklearnKMeans( + n_clusters=self.num_clusters, + init='k-means++', + max_iter=50, + n_init=1, + tol=1e-4, + random_state=i * 42) + sklearn_kmeans.train(self.points) + scores.append(sklearn_kmeans.inertia_) + self._report(num_iters, start, time.time(), scores) + + +class KMeansTestQueues(test.TestCase): + + def input_fn(self): + + def _fn(): + queue = data_flow_ops.FIFOQueue( + capacity=10, dtypes=dtypes.float32, shapes=[10, 3]) + enqueue_op = queue.enqueue(array_ops.zeros([10, 3], dtype=dtypes.float32)) + queue_runner.add_queue_runner( + queue_runner.QueueRunner(queue, [enqueue_op])) + return queue.dequeue(), None + + return _fn + + # This test makes sure that there are no deadlocks when using a QueueRunner. + # Note that since cluster initialization is dependendent on inputs, if input + # is generated using a QueueRunner, one has to make sure that these runners + # are started before the initialization. + def test_queues(self): + kmeans = kmeans_lib.KMeansClustering(5) + kmeans.train(input_fn=self.input_fn(), steps=1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3e3ee5fa57f1356db98a17f9e17e60f01d85d3b9..3976395d78e9188dd56d5b3b32fa8a3daf43c37d 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -38,31 +37,30 @@ from tensorflow.python.training import session_run_hook class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_op, num_rows, num_cols, - processed_row_indices, processed_col_indices, row_prep_ops, - col_prep_ops, cache_init_ops, completed_sweeps_var): + def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, + input_row_indices, input_col_indices, row_prep_ops, + col_prep_ops, init_op, completed_sweeps_var): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_op: An op. All the ops created by the hook will have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this hook will have + control dependencies on `train_ops`. num_rows: int, the total number of rows to be processed. num_cols: int, the total number of columns to be processed. - processed_row_indices: A Tensor of type int64. The indices of the input - rows that are processed during the current sweep. All elements of - processed_row_indices must be in [0, num_rows). - processed_col_indices: A Tensor of type int64. The indices of the input + input_row_indices: A Tensor of type int64. The indices of the input rows + that are processed during the current sweep. All elements of + `input_row_indices` must be in [0, num_rows). + input_col_indices: A Tensor of type int64. The indices of the input columns that are processed during the current sweep. All elements of - processed_col_indices must be in [0, num_cols). + `input_col_indices` must be in [0, num_cols). row_prep_ops: list of ops, to be run before the beginning of each row sweep, in the given order. col_prep_ops: list of ops, to be run before the beginning of each column sweep, in the given order. - cache_init_ops: list of ops, to be run once before training, in the given - order. These are typically local initialization ops (such as cache - initialization). + init_op: op to be run once before training. This is typically a local + initialization op (such as cache initialization). completed_sweeps_var: An integer tf.Variable, indicates the number of completed sweeps. It is updated by the hook. """ @@ -70,55 +68,45 @@ class _SweepHook(session_run_hook.SessionRunHook): self._num_cols = num_cols self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._cache_init_ops = cache_init_ops + self._init_op = init_op self._is_row_sweep_var = is_row_sweep_var self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the cache_init_ops have been run. + # Boolean variable that determines whether the init_ops have been run. self._is_initialized = False - # Boolean variable that is set to True when a sweep is completed. - # Used to run the prep_ops at the beginning of a sweep, in before_run(). - self._is_sweep_done = False - # Ops to run jointly with train_op, responsible for updating - # _is_row_sweep_var and incrementing the global_step and completed_sweeps - # counters. They have control_dependencies on train_op. - self._fetches = self._create_switch_ops(processed_row_indices, - processed_col_indices, train_op) - - def _create_switch_ops(self, processed_row_indices, processed_col_indices, - train_op): + # Ops to run jointly with train_ops, responsible for updating + # `is_row_sweep_var` and incrementing the `global_step` and + # `completed_sweeps` counters. + self._update_op, self._is_sweep_done_var, self._switch_op = ( + self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) + + def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - Creates two boolean tensors processed_rows and processed_cols, which keep - track of which rows/cols have been processed during the current sweep. + Creates two boolean tensors `processed_rows` and `processed_cols`, which + keep track of which rows/cols have been processed during the current sweep. Returns ops that should be run after each row / col update. - - When is_row_sweep_var is True, it sets - processed_rows[processed_row_indices] to True. - - When is_row_sweep_var is False, it sets - processed_cols[processed_col_indices] to True . - When all rows or all cols have been processed, negates is_row_sweep_var, - increments the completed_sweeps counter, and resets processed_rows and - processed_cols to False. - All of the ops created by this function have control_dependencies on - train_op. + - When `self._is_row_sweep_var` is True, it sets + processed_rows[input_row_indices] to True. + - When `self._is_row_sweep_var` is False, it sets + processed_cols[input_col_indices] to True. Args: - processed_row_indices: A Tensor. The indices of the input rows that are + input_row_indices: A Tensor. The indices of the input rows that are processed during the current sweep. - processed_col_indices: A Tensor. The indices of the input columns that + input_col_indices: A Tensor. The indices of the input columns that are processed during the current sweep. - train_op: An op. All the ops created by this function have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this function have control + dependencies on `train_ops`. + Returns: - A list consisting of: - is_sweep_done: A Boolean tensor, determines whether the sweep is done, - i.e. all rows (during a row sweep) or all columns (during a column - sweep) have been processed. - switch_ops: An op that updates is_row_sweep_var when is_sweep_done is - True. Has control_dependencies on train_op. - incr_ops: An op that increments the global_step and completed_sweeps - counters. Has control_dependenciens on switch_ops. + A tuple consisting of: + update_op: An op to be run jointly with training. It updates the state + and increments counters (global step and completed sweeps). + is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is + done, i.e. all rows (during a row sweep) or all columns (during a + column sweep) have been processed. + switch_op: An op to be run in `self.before_run` when the sweep is done. """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) with ops.colocate_with(processed_rows_init): processed_rows = variable_scope.variable( @@ -133,97 +121,72 @@ class _SweepHook(session_run_hook.SessionRunHook): collections=[ops.GraphKeys.GLOBAL_VARIABLES], trainable=False, name="sweep_hook_processed_cols") - # After running the train_op, update processed_rows or processed_cols - # tensors, depending on whether we are currently doing a row or a col sweep - with ops.control_dependencies([train_op]): - - def get_row_update_op(): - with ops.colocate_with(processed_rows): - return state_ops.scatter_update(processed_rows, processed_row_indices, - array_ops.ones_like( - processed_row_indices, - dtype=dtypes.bool)) - - def get_col_update_op(): - with ops.colocate_with(processed_cols): - return state_ops.scatter_update(processed_cols, processed_col_indices, - array_ops.ones_like( - processed_col_indices, - dtype=dtypes.bool)) - - update_processed_op = control_flow_ops.cond( - self._is_row_sweep_var, get_row_update_op, get_col_update_op) - - # After update_processed_op, check whether we have completed a sweep. - # If this is the case, flip the is_row_sweep_var and reset processed_rows - # and processed_cols tensors. - with ops.control_dependencies([update_processed_op]): - - def get_switch_op(): - return state_ops.assign( - self._is_row_sweep_var, - gen_math_ops.logical_not(self._is_row_sweep_var)).op - - def get_reset_op(): - return control_flow_ops.group( - state_ops.assign(processed_rows, processed_rows_init).op, - state_ops.assign(processed_cols, processed_cols_init).op) - - is_sweep_done = control_flow_ops.cond( + switch_ops = control_flow_ops.group( + state_ops.assign( self._is_row_sweep_var, - lambda: math_ops.reduce_all(processed_rows), - lambda: math_ops.reduce_all(processed_cols), - name="sweep_hook_is_sweep_done") - switch_op = control_flow_ops.cond( - is_sweep_done, - get_switch_op, - control_flow_ops.no_op, - name="sweep_hook_switch_op") - reset_op = control_flow_ops.cond( - is_sweep_done, - get_reset_op, - control_flow_ops.no_op, - name="sweep_hook_reset_op") - switch_ops = control_flow_ops.group( - switch_op, reset_op, name="sweep_hook_switch_ops") - - with ops.control_dependencies([switch_ops]): - # Op to increment the completed_sweeps counter. - completed_sweeps_incr_op = control_flow_ops.cond( - is_sweep_done, - lambda: state_ops.assign_add(self._completed_sweeps_var, 1).op, - control_flow_ops.no_op, - name="completed_sweeps_incr") - - # Op to increment the global_step counter. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op( - name="global_step_incr") - - incr_ops = control_flow_ops.group( - completed_sweeps_incr_op, - global_step_incr_op, - name="counter_incr_ops") - - return [is_sweep_done, switch_ops, incr_ops] + math_ops.logical_not(self._is_row_sweep_var)), + state_ops.assign(processed_rows, processed_rows_init), + state_ops.assign(processed_cols, processed_cols_init)) + is_sweep_done_var = variable_scope.variable( + False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="is_sweep_done") + + # After running the `train_ops`, updates `processed_rows` or + # `processed_cols` tensors, depending on whether this is a row or col sweep. + with ops.control_dependencies(train_ops): + with ops.colocate_with(processed_rows): + update_processed_rows = state_ops.scatter_update( + processed_rows, + input_row_indices, + math_ops.logical_and( + self._is_row_sweep_var, + array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) + with ops.colocate_with(processed_cols): + update_processed_cols = state_ops.scatter_update( + processed_cols, + input_col_indices, + math_ops.logical_and( + math_ops.logical_not(self._is_row_sweep_var), + array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) + update_processed_op = control_flow_ops.group( + update_processed_rows, update_processed_cols) - def begin(self): - pass + with ops.control_dependencies([update_processed_op]): + is_sweep_done = math_ops.logical_or( + math_ops.reduce_all(processed_rows), + math_ops.reduce_all(processed_cols)) + # Increments global step. + global_step = framework_variables.get_global_step() + if global_step is not None: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op() + # Increments completed sweeps. + completed_sweeps_incr_op = state_ops.assign_add( + self._completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32), + use_locking=True).op + update_ops = control_flow_ops.group( + global_step_incr_op, + completed_sweeps_incr_op, + state_ops.assign(is_sweep_done_var, is_sweep_done)) + + return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Run the appropriate cache_init and prep ops + # Runs the appropriate init ops and prep ops. sess = run_context.session + is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init ops.") - for init_op in self._cache_init_ops: - sess.run(init_op) - - if self._is_sweep_done or not self._is_initialized: + logging.info("SweepHook running cache init op.") + sess.run(self._init_op) + if is_sweep_done: + sess.run(self._switch_op) + if is_sweep_done or not self._is_initialized: logging.info("SweepHook running sweep prep ops.") row_sweep = sess.run(self._is_row_sweep_var) prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops @@ -232,13 +195,12 @@ class _SweepHook(session_run_hook.SessionRunHook): self._is_initialized = True - # Request running the switch_ops and the incr_ops - logging.info("Partial fit starting.") - return session_run_hook.SessionRunArgs(fetches=self._fetches) + # Requests running `self._update_op` jointly with the training op. + logging.info("Next fit step starting.") + return session_run_hook.SessionRunArgs(fetches=[self._update_op]) def after_run(self, run_context, run_values): - self._is_sweep_done = run_values.results[0] - logging.info("Partial fit done.") + logging.info("Fit step done.") class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -360,19 +322,19 @@ def _wals_factorization_model_function(features, labels, mode, params): col_prep_ops = [ model.col_update_prep_gramian_op, model.initialize_col_update_op ] - cache_init_ops = [model.worker_init] + init_ops = [model.worker_init] sweep_hook = _SweepHook( is_row_sweep_var, - train_op, + [train_op, loss], params["num_rows"], params["num_cols"], input_row_indices, input_col_indices, row_prep_ops, col_prep_ops, - cache_init_ops, - completed_sweeps_var,) + init_ops, + completed_sweeps_var) training_hooks = [sweep_hook] if max_sweeps is not None: training_hooks.append(_StopAtSweepHook(max_sweeps)) diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index b5c1bb1151e78a8f19d3c91b57ef3bfd6152893d..8bd72b7025aad80e387171b93b9b264da3ed0f66 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -357,7 +357,7 @@ class WALSMatrixFactorizationTest(test.TestCase): self.assertNear( loss, true_loss, err=.001, - msg="""After row update, eval loss = {}, does not match the true + msg="""After col update, eval loss = {}, does not match the true loss = {}.""".format(loss, true_loss)) @@ -442,7 +442,7 @@ class SweepHookTest(test.TestCase): completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - self._train_op, + [self._train_op], self._num_rows, self._num_cols, self._input_row_indices_ph, @@ -465,11 +465,9 @@ class SweepHookTest(test.TestCase): 'False.') # Row sweep completed. mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertFalse(sess.run(is_row_sweep_var), - msg='Row sweep is complete but is_row_sweep is True.') self.assertTrue(sess.run(completed_sweeps_var) == 1, msg='Completed sweeps should be equal to 1.') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') # Col init ops should run. Col sweep not completed. mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) @@ -478,13 +476,11 @@ class SweepHookTest(test.TestCase): self.assertFalse(sess.run(is_row_sweep_var), msg='Col sweep is not complete but is_row_sweep is ' 'True.') - self.assertFalse(sweep_hook._is_sweep_done, + self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is not complete but is_sweep_done is True.') # Col sweep completed. mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(is_row_sweep_var), - msg='Col sweep is complete but is_row_sweep is False') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') self.assertTrue(sess.run(completed_sweeps_var) == 2, msg='Completed sweeps should be equal to 2.') diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 05fc658d80f26b00f775211cf89f55ce18a4502d..949ae9ad9e4b045ee1b5cc82d49c0e7468c2005d 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -23,6 +23,18 @@ cc_library( ], ) +tf_cc_test( + name = "ffmpeg_lib_utility_test", + srcs = ["ffmpeg_lib_utility_test.cc"], + deps = [ + ":ffmpeg_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "ffmpeg_lib_installed_test", srcs = ["ffmpeg_lib_test.cc"], diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index b417a70b6e63310a5a1d9a82522cd5e678e7b6b0..545a4386d043af604a747b8b5a8103101812b177 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -198,6 +198,14 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, return data; } +// Returns a unique number every time it is called. +int64 UniqueId() { + static mutex mu(LINKER_INITIALIZED); + static int64 id = 0; + mutex_lock l(mu); + return ++id; +} + } // namespace string GetTempFilename(const string& extension) { @@ -208,8 +216,12 @@ string GetTempFilename(const string& extension) { } struct stat statbuf; if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) { - string tmp_filepath = - io::JoinPath(dir, StrCat("tmp_file_XXXXXX", ".", extension)); + // UniqueId is added here because mkstemps is not as thread safe as it + // looks. https://github.com/tensorflow/tensorflow/issues/5804 shows + // the problem. + string tmp_filepath = io::JoinPath( + dir, + StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX.", extension)); int fd = mkstemps(&tmp_filepath[0], extension.length() + 1); if (fd < 0) { LOG(FATAL) << "Failed to create temp file."; diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7176f3b550679555d5ab3b70f2b360a90eaee253 --- /dev/null +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -0,0 +1,80 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" + +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace ffmpeg { +namespace { + +TEST(FfmpegLibTest, TestTempDirectoryThreading) { + // Testing a fix for a bug that allowed different threads to create + // conflicting temp files. + // See github.com/tensorflow/tensorflow/issues/5804 for details. + const int32 kNumThreads = 10; + const int32 kNumWorkItems = 10000; + static constexpr size_t kStringsPerItem = 100; + Env* environment = Env::Default(); + thread::ThreadPool pool(environment, "test", kNumThreads); + + mutex mu; + std::vector temp_filenames; + temp_filenames.reserve(kNumWorkItems * kStringsPerItem); + + // Queue a large number of work items for the threads to process. Each work + // item creates a temp file and then deletes it. + for (int i = 0; i < kNumWorkItems; ++i) { + pool.Schedule([&mu, &temp_filenames, environment]() { + std::array buffer; + for (int32 j = 0; j < kStringsPerItem; ++j) { + buffer[j] = GetTempFilename("mp3"); + TF_QCHECK_OK(environment->DeleteFile(buffer[j])); + } + mutex_lock l(mu); + for (const auto& fn : buffer) { + temp_filenames.push_back(fn); + } + }); + } + + // Wait until all work items are complete. + while (true) { + mutex_lock l(mu); + if (temp_filenames.size() == kNumWorkItems * kStringsPerItem) { + break; + } + } + + // Check that no duplicates are created. + std::set unique_filenames; + mutex_lock l(mu); + for (const auto& fn : temp_filenames) { + ASSERT_TRUE(unique_filenames.insert(fn).second); + } +} + +} // namespace +} // namespace ffmpeg +} // namespace tensorflow diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 6b0599ddd2def8dd698a1bd152b5be926c6ddf3e..90aed3065b1e8238886820698260eba049017042 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -10,9 +10,8 @@ package(default_visibility = [ "//tensorflow:__subpackages__", ]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") @@ -27,6 +26,7 @@ tf_custom_op_py_library( "python/framework/experimental.py", "python/framework/tensor_util.py", "python/ops/__init__.py", + "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", @@ -149,6 +149,31 @@ py_test( ], ) +py_test( + name = "accumulate_n_v2_test", + size = "small", + srcs = ["python/ops/accumulate_n_v2_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +py_test( + name = "accumulate_n_v2_eager_test", + size = "small", + srcs = ["python/ops/accumulate_n_v2_eager_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/eager:backprop", + ], +) + py_test( name = "ops_test", size = "small", @@ -214,7 +239,6 @@ py_test( deps = [ ":framework_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", @@ -222,6 +246,7 @@ py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", @@ -254,7 +279,6 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:partitioned_variables", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 2081a11f47d71106f8e57227f46639717a791855..8421ba7c0423c6ed274f92ba74930822d0171e05 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide. @@arg_scope @@add_arg_scope +@@current_arg_scope @@has_arg_scope @@arg_scoped_arguments diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index 9681a03767dadac655fe8c4758f960349ad10cf4..92a2a4ff2d1cb41c48312038d82be0b6136f8d41 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -78,9 +78,9 @@ def reduce_sum_n(tensors, name=None): return math_ops.add_n(tensors, name=name_scope) @deprecated(None, - "Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note " - "that order of the inputs and outputs of labels and predictions have also " - "been switched.") + 'Please switch to tf.confusion_matrix.remove_squeezable_dimensions.' + 'Note that order of the inputs and outputs of labels and ' + 'predictions have also been switched.') def remove_squeezable_dimensions(predictions, labels, name=None): """Squeeze last dim if ranks of `predictions` and `labels` differ by 1. diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..a0667bd489213cf366e27114a91e8699ed9e7428 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -0,0 +1,111 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops + + + +def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): + """Returns the element-wise sum of a list of tensors. + + Optionally, pass `shape` and `tensor_dtype` for shape and type checking, + otherwise, these are inferred. + + `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not + wait for all of its inputs to be ready before beginning to sum. This can + save memory if inputs are ready at different times, since minimum temporary + storage is proportional to the output size rather than the inputs size. + + Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. + + For example: + + ```python + a = tf.constant([[1, 2], [3, 4]]) + b = tf.constant([[5, 0], [0, 6]]) + tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] + + # Explicitly pass shape and type + tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + # [[7, 4], + # [6, 14]] + ``` + + Args: + inputs: A list of `Tensor` objects, each with same shape and type. + shape: Shape of elements of `inputs`. + tensor_dtype: The type of `inputs`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of same shape and type as the elements of `inputs`. + + Raises: + ValueError: If `inputs` don't all have same shape and dtype or the shape + cannot be inferred. + """ + _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" + "with the same dtype and shape") + if not inputs or not isinstance(inputs, (list, tuple)): + raise _INPUTS_ERR_MSG + inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) + if not all(isinstance(x, ops.Tensor) for x in inputs): + raise _INPUTS_ERR_MSG + if not all(x.dtype == inputs[0].dtype for x in inputs): + raise _INPUTS_ERR_MSG + if shape is not None: + shape = tensor_shape.as_shape(shape) + else: + shape = tensor_shape.unknown_shape() + for input_tensor in inputs: + if isinstance(input_tensor, ops.Tensor): + shape = shape.merge_with(input_tensor.get_shape()) + + # tensor_dtype is for safety only; operator's output type computed in C++ + if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: + raise TypeError("tensor_dtype is {}, but input is of type {}" + .format(tensor_dtype, inputs[0].dtype)) + + if len(inputs) == 1 and name is None: + return inputs[0] + elif len(inputs) == 1 and name is not None: + return array_ops.identity(inputs[0], name=name) + elif context.in_eager_mode(): + # TemporaryVariable not currently supported in eager mode; fall back + # onto AddN for now. + # TODO(frreiss) remove this once the lifetime of eager variables gets + # addressed + return math_ops.add_n(inputs, name=name) + else: + return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) + +# The following code should eventually be merged into +# tensorflow/python/ops/math_grad.py +@ops.RegisterGradient("AccumulateNV2") +def _AddNGrad(op, grad): + """Same as gradient for AddN. Copies the gradient to all inputs.""" + # Not broadcasting. + return [grad] * len(op.inputs) + diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c2229bb8ad3d5b38321d16f150ed94175ab9bdbe --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -0,0 +1,85 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for new version of accumulate_n op that will eventually go into +`ops.math_ops`. + +These test cases spefically exercise the `eager` APIs. They need to be in a +separate file from the remaining tests because eager mode is currently something +you can turn on but can't turn off for the lifetime of the current process.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 + +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context as eager_context +from tensorflow.python.eager import tape + + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + + +class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): + """Tests of the new, differentiable version of accumulate_n""" + + def testMinimalEagerMode(self): + forty = constant_op.constant(40) + two = constant_op.constant(2) + answer = av2.accumulate_n_v2([forty, two]) + self.assertEqual(42, answer.numpy()) + + + def testFloat(self): + np.random.seed(12345) + x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).numpy()) + self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).numpy()) + + def testGrad(self): + np.random.seed(42) + num_inputs = 3 + input_vars = [ + resource_variable_ops.ResourceVariable(10.0 * np.random.random(), + name="t%d" % i) + for i in range(0, num_inputs) + ] + + def fn(first, second, third): + return av2.accumulate_n_v2([first, second, third]) + + grad_fn = backprop.gradients_function(fn) + grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [elem.numpy() for elem in grad]) + + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() + diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3386e849d5cb8516ab3b1f6cb0429be3fc2fc960 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -0,0 +1,123 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for new version of accumulate_n op that will eventually go into +`ops.math_ops`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + + +class AccumulateNV2Test(test_util.TensorFlowTestCase): + """Tests of the new, differentiable version of accumulate_n""" + + def testFloat(self): + np.random.seed(12345) + x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).eval()) + self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).eval()) + + def testInt(self): + np.random.seed(54321) + x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllEqual(sum(x), av2.accumulate_n_v2(tf_x).eval()) + self.assertAllEqual(x[0] * 6, av2.accumulate_n_v2([tf_x[0]] * 6).eval()) + + def testGrad(self): + np.random.seed(42) + for num_inputs in range(1, 10): + with self.test_session(use_gpu=True) as sess: + input_vars = [ + variables.Variable(10.0 * np.random.random()) + for i in range(0, num_inputs) + ] + accum_n = av2.accumulate_n_v2(input_vars) + sess.run(variables.global_variables_initializer()) + accum_n_grad = gradients.gradients(accum_n, input_vars) + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [g.eval() for g in accum_n_grad]) + + # The tests below used to be in a separate class under cwise_ops_test.py, + # which did not run in the default test target. + # Putting them here so that everything that exercises AccumulateNV2 is in + # one place and the default build runs all unit tests. + def testSimple(self): + with self.test_session(): + random_arrays = [ + np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20) + ] + random_tensors = [ + ops.convert_to_tensor( + x, dtype=dtypes_lib.float32) for x in random_arrays + ] + tf_val = av2.accumulate_n_v2(random_tensors) + np_val = random_arrays[0] + for random_array in random_arrays[1:]: + np_val += random_array + self.assertAllClose(np_val, tf_val.eval()) + + def testZeroArgs(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf_val = av2.accumulate_n_v2([]) + tf_val.eval() + + def testWrongShape(self): + with self.test_session(): + with self.assertRaises(ValueError): + a = variables.Variable(0.2) + b = variables.Variable(0.1) + tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[] + + def testIncompatibleShapes(self): + with self.test_session(): + with self.assertRaises(ValueError): + a = variables.Variable(np.array([0.1,0.2])) + b = variables.Variable(np.array([[0.3],[0.4]])) + tf_val = av2.accumulate_n_v2([a,b]) + + def testWrongType(self): + with self.test_session(): + with self.assertRaises(TypeError): + a = variables.Variable(0.2, dtype=np.float32) + b = variables.Variable(0.1, dtype=np.float32) + tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + + def testWrongTypeOneInput(self): + # Scenario that used to trigger a bug, even when testWrongType() worked + with self.test_session(): + with self.assertRaises(TypeError): + a = variables.Variable(0.2, dtype=np.float32) + tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 9c194ec202ab6150278b26e844b9d3e97a7d6761..2bce00fde2459878a12027bb4d98bd3818bc92a2 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -67,6 +67,7 @@ from tensorflow.python.util import tf_decorator __all__ = ['arg_scope', 'add_arg_scope', + 'current_arg_scope', 'has_arg_scope', 'arg_scoped_arguments'] @@ -83,7 +84,7 @@ def _get_arg_stack(): return _ARGSTACK -def _current_arg_scope(): +def current_arg_scope(): stack = _get_arg_stack() return stack[-1] @@ -144,7 +145,7 @@ def arg_scope(list_ops_or_scope, **kwargs): raise TypeError('list_ops_or_scope must either be a list/tuple or reused' 'scope (i.e. dict)') try: - current_scope = _current_arg_scope().copy() + current_scope = current_arg_scope().copy() for op in list_ops_or_scope: key_op = _key_op(op) if not has_arg_scope(op): @@ -172,7 +173,7 @@ def add_arg_scope(func): A tuple with the decorated function func_with_args(). """ def func_with_args(*args, **kwargs): - current_scope = _current_arg_scope() + current_scope = current_arg_scope() current_args = kwargs key_func = _key_op(func) if key_func in current_scope: diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 256f2008687bc97a6e897c5d833ebb2d559383e0..88306094ab9947c9c78b03c0013f6afc88316803 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -298,6 +298,17 @@ void LaunchFusedConv2DBiasActivationOp:: constexpr int rank = is_int8x4 ? 5 : 4; constexpr int vect = is_int8x4 ? 4 : 1; + if (is_int8x4) { + int cc_major, cc_minor; + stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + OP_REQUIRES( + ctx, cc_major >= 6 && cc_minor >= 1, + errors::Unimplemented( + "FusedConv2DBiasActivation for int8 is only supported on GPUs with " + "compute capability 6.1 or later.")); + } + const int batch_size = GetTensorDim(conv_input_param, data_format, 'N'); int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H'); int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W'); @@ -434,11 +445,11 @@ void LaunchFusedConv2DBiasActivationOp:: .set_zero_padding_width(padding_cols / 2); Tensor maybe_transformed_filter; - const Tensor* filter; - if (is_int8x4) { - // We have already checked filter is OIHW_VECT_I in the constructor. - filter = &filter_param; - } else if (filter_format == FORMAT_HWIO) { + const Tensor* filter = &filter_param; + // For qint8, we have already checked filter is OIHW_VECT_I in the + // constructor, but we need to test for is_int8x4 so the if block doesn't + // generate code for qint8. + if (!is_int8x4 && filter_format == FORMAT_HWIO) { // Shuffle filter tensor from HWIO to OIHW: OP_REQUIRES_OK(ctx, ctx->allocate_temp( DataTypeToEnum::value, diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 3b8f7d6ed760647c4c61ce5ea60be1d7d17ddfa0..2a18f3eeecc7e0e69c54b219886a263136f01b2c 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -159,9 +159,12 @@ class FusedConv2DBiasActivationTest(test.TestCase): def _DtypesToTest(self, use_gpu): return [dtypes.float32] + def _FilterFormatsToTest(self, use_gpu): + return ["HWIO", "OIHW"] + def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias, strides, padding, activation_mode, data_format, - dtype): + filter_format, dtype): """Verifies the output values of the convolution function. Args: @@ -174,6 +177,7 @@ class FusedConv2DBiasActivationTest(test.TestCase): padding: Padding type. activation_mode: Activation mode. data_format: Format of the data tensors. + filter_format: Filter format to use for the fused convolution. dtype: Data type for inputs and outputs. Returns: Symbolic tensor value and reference value that can be used to @@ -192,6 +196,9 @@ class FusedConv2DBiasActivationTest(test.TestCase): with self.test_session(use_gpu=True): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) + fused_t2 = t2 + if filter_format == "OIHW": + fused_t2 = HwioToOihw(t2) t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype) strides = [1] + strides + [1] if data_format == "NCHW": @@ -199,11 +206,12 @@ class FusedConv2DBiasActivationTest(test.TestCase): strides = test_util.NHWCToNCHW(strides) output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( t1, - t2, + fused_t2, t3, strides=strides, padding=padding, data_format=data_format, + filter_format=filter_format, activation_mode=activation_mode) ref_conv_output = nn_ops.conv2d( t1, t2, strides=strides, padding=padding, data_format=data_format) @@ -268,9 +276,10 @@ class FusedConv2DBiasActivationTest(test.TestCase): ref_tensors = [] for (data_format, use_gpu) in GetTestConfigs(): for dtype in self._DtypesToTest(use_gpu): - result, expected = self._SetupValuesForDevice( - tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", - data_format, dtype) + for filter_format in self._FilterFormatsToTest(use_gpu): + result, expected = self._SetupValuesForDevice( + tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", + data_format, filter_format, dtype) tensors.append(result) ref_tensors.append(expected) with self.test_session() as sess: @@ -607,6 +616,10 @@ def NchwToNchwVectC(in_tensor): return array_ops.transpose(t, [0, 1, 3, 4, 2]) +def HwioToOihw(in_tensor): + return array_ops.transpose(in_tensor, [3, 2, 0, 1]) + + def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, padding, strides, side_input_scale, side_input, biases): diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 64bff7cecf5691800b91366bcd10ad1e3449d1b5..1418c87023af0dbff890f46e10f0140d5b89e4b7 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -202,6 +202,7 @@ py_library( "//tensorflow/python:embedding_ops", "//tensorflow/python:math_ops", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", ], ) @@ -234,6 +235,7 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", ], ) @@ -267,7 +269,10 @@ py_library( "python/features/python/clip_weights_impl.py", ], srcs_version = "PY2AND3", - deps = ["//tensorflow/contrib/opt:opt_py"], + deps = [ + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:util", + ], ) py_test( @@ -441,6 +446,7 @@ py_test( srcs = ["python/estimator/python/gan_estimator_test.py"], shard_count = 1, srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":gan_estimator", ":namedtuples", diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 10458a2458384c8f589183003256db24d69742d7..5d74df3ef70be03060e9cd30249359ee07b4b83a 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -51,9 +51,10 @@ network to evaluate your unconditional generative model. You can also also use your own pretrained classifier for more specific performance numbers, or use other methods for evaluating conditional generative models. -* [examples](https://github.com/tensorflow/models/tree/master/gan/): +* examples (coming soon): See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your -own project. +own project. These include unconditional and conditional GANs, InfoGANs, +adversarial losses on existing networks, and image-to-image translation. ## Training a GAN model diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 6e1ee730aac51f8ee9f81248c41e97f75f935eff..e89993991a389d68254a95aded2d771f4c2627be 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -238,7 +238,7 @@ def _make_train_gan_model(generator_fn, discriminator_fn, real_data, if add_summaries: if not isinstance(add_summaries, (tuple, list)): add_summaries = [add_summaries] - with ops.name_scope(''): + with ops.name_scope(None): for summary_type in add_summaries: _summary_type_map[summary_type](gan_model) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 3a6456f038b06d6cc352c012d14e7d6ebdfc29c5..d4c080cab3d82f6a69a293e84e1c08322bbb6f86 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -16,6 +16,11 @@ These methods come from https://arxiv.org/abs/1606.03498 and https://arxiv.org/abs/1706.08500. + +NOTE: This implementation uses the same weights as in +https://github.com/openai/improved-gan/blob/master/inception_score/model.py, +but is more numerically stable and is an unbiased estimator of the true +Inception score even when splitting the inputs into batches. """ from __future__ import absolute_import @@ -54,17 +59,16 @@ __all__ = [ 'classifier_score', 'frechet_inception_distance', 'frechet_classifier_distance', + 'INCEPTION_DEFAULT_IMAGE_SIZE', ] -INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v3_2017_09_13.tar.gz' -INCEPTION_FROZEN_GRAPH = 'frozen_inception_v3.pb' -INCEPTION_V3_INPUT = 'input' -INCEPTION_V3_OUTPUT = 'InceptionV3/Logits/SpatialSqueeze:0' -INCEPTION_V3_FINAL_POOL = 'InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0' -_INCEPTION_V3_NUM_CLASSES = 1001 -_INCEPTION_V3_FINAL_POOL_SIZE = 2048 -INCEPTION_V3_DEFAULT_IMG_SIZE = 299 +INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' +INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' +INCEPTION_INPUT = 'Mul:0' +INCEPTION_OUTPUT = 'logits:0' +INCEPTION_FINAL_POOL = 'pool_3:0' +INCEPTION_DEFAULT_IMAGE_SIZE = 299 def _validate_images(images, image_size): @@ -102,46 +106,37 @@ def _symmetric_matrix_square_root(mat, eps=1e-10): math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) -# Convenience preprocessing function, with fixed defaults. -# NOTE: Floating-point inputs are expected to be in [0, 1]. -# Copied from /tensorflow_models/slim/preprocessing/inception_preprocessing.py. def preprocess_image( - image, height=INCEPTION_V3_DEFAULT_IMG_SIZE, - width=INCEPTION_V3_DEFAULT_IMG_SIZE, central_fraction=0.875, scope=None): - """Prepare one image for evaluation. + images, height=INCEPTION_DEFAULT_IMAGE_SIZE, + width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None): + """Prepare a batch of images for evaluation. - If height and width are specified it would output an image with that size by - applying resize_bilinear. + This is the preprocessing portion of the graph from + http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz. - If central_fraction is specified it would crop the central fraction of the - input image. + Note that it expects Tensors in [0, 255]. This function maps pixel values to + [-1, 1] and resizes to match the InceptionV1 network. Args: - image: 3-D Tensor of image. If dtype is tf.float32 then the range should be - [0, 1], otherwise it would converted to tf.float32 assuming that the range - is [0, MAX], where MAX is largest positive representable number for - int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). - height: integer - width: integer - central_fraction: Optional Float, fraction of the image to crop. + images: 3-D or 4-D Tensor of images. Values are in [0, 255]. + height: Integer. Height of resized output image. + width: Integer. Width of resized output image. scope: Optional scope for name_scope. + Returns: - 3-D float Tensor of prepared image. + 3-D or 4-D float Tensor of prepared image(s). Values are in [-1, 1]. """ - with ops.name_scope(scope, 'eval_image', [image, height, width]): - if image.dtype != dtypes.float32: - image = image_ops.convert_image_dtype(image, dtype=dtypes.float32) - # Crop the central region of the image with an area containing 87.5% of - # the original image. - image = image_ops.central_crop(image, central_fraction=central_fraction) - - # Resize the image to the specified height and width. - image = array_ops.expand_dims(image, 0) - image = image_ops.resize_bilinear(image, [height, width], - align_corners=False) - image = array_ops.squeeze(image, [0]) - image = (image - 0.5) * 2.0 - return image + is_single = images.shape.ndims == 3 + with ops.name_scope(scope, 'preprocess', [images, height, width]): + if not images.dtype.is_floating: + images = math_ops.to_float(images) + images = (images - 128.0) / 128.0 + if is_single: + images = array_ops.expand_dims(images, axis=0) + resized = image_ops.resize_bilinear(images, [height, width]) + if is_single: + resized = array_ops.squeeze(resized, axis=0) + return resized def _kl_divergence(p, p_logits, q): @@ -211,9 +206,9 @@ def _default_graph_def_fn(): def run_inception(images, graph_def=None, default_graph_def_fn=_default_graph_def_fn, - image_size=INCEPTION_V3_DEFAULT_IMG_SIZE, - input_tensor=INCEPTION_V3_INPUT, - output_tensor=INCEPTION_V3_OUTPUT): + image_size=INCEPTION_DEFAULT_IMAGE_SIZE, + input_tensor=INCEPTION_INPUT, + output_tensor=INCEPTION_OUTPUT): """Run images through a pretrained Inception classifier. Args: @@ -317,19 +312,28 @@ def classifier_score(images, classifier_fn, num_batches=1): name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) logits.shape.assert_has_rank(2) + + # Use maximum precision for best results. + logits_dtype = logits.dtype + if logits_dtype != dtypes.float64: + logits = math_ops.cast(logits, dtypes.float64) + p = nn_ops.softmax(logits) q = math_ops.reduce_mean(p, axis=0) kl = _kl_divergence(p, logits, q) kl.shape.assert_has_rank(1) log_score = math_ops.reduce_mean(kl) + final_score = math_ops.exp(log_score) - return math_ops.exp(log_score) + if logits_dtype != dtypes.float64: + final_score = math_ops.cast(final_score, dtypes.float64) + return final_score inception_score = functools.partial( classifier_score, classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_V3_OUTPUT)) + run_inception, output_tensor=INCEPTION_OUTPUT)) def trace_sqrt_product(sigma, sigma_v): @@ -470,4 +474,4 @@ def frechet_classifier_distance(real_images, frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_V3_FINAL_POOL)) + run_inception, output_tensor=INCEPTION_FINAL_POOL)) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 30285964a53c388d4f9aaf65b6cabed362b3b012..81fa2fc0f126647d2f01a1f4fc695d714eba2c75 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -68,7 +68,7 @@ def _expected_trace_sqrt_product(sigma, sigma_v): # A dummy GraphDef string with the minimum number of Ops. graphdef_string = """ node { - name: "input" + name: "Mul" op: "Placeholder" attr { key: "dtype" @@ -97,7 +97,7 @@ node { } } node { - name: "InceptionV3/Logits/SpatialSqueeze" + name: "logits" op: "Placeholder" attr { key: "dtype" @@ -120,7 +120,7 @@ node { } } node { - name: "InceptionV3/Logits/AvgPool_1a_8x8/AvgPool" + name: "pool_3" op: "Placeholder" attr { key: "dtype" @@ -182,7 +182,7 @@ class ClassifierMetricsTest(test.TestCase): img = array_ops.ones([batch_size, 299, 299, 3]) pool = _run_with_mock( classifier_metrics.run_inception, img, - output_tensor=classifier_metrics.INCEPTION_V3_FINAL_POOL) + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -306,7 +306,7 @@ class ClassifierMetricsTest(test.TestCase): """Test `preprocess_image` graph construction.""" incorrectly_sized_image = array_ops.zeros([520, 240, 3]) correct_image = classifier_metrics.preprocess_image( - image=incorrectly_sized_image) + images=incorrectly_sized_image) _run_with_mock(classifier_metrics.run_inception, array_ops.expand_dims(correct_image, 0)) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 2a40dbade6771674ec6d1c47a8b519ed186cb5e6..940762cf2aa0f473cd41d9d543e2773b565a5248 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -217,21 +217,25 @@ def acgan_discriminator_loss( Raises: TypeError: If the discriminator does not output a tuple. """ - loss_on_generated = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, - weights=generated_weights, scope=scope, loss_collection=None, - reduction=reduction) - loss_on_real = losses.softmax_cross_entropy( - one_hot_labels, discriminator_real_classification_logits, - weights=real_weights, label_smoothing=label_smoothing, scope=scope, - loss_collection=None, reduction=reduction) - loss = loss_on_generated + loss_on_real - util.add_loss(loss, loss_collection) + with ops.name_scope( + scope, 'acgan_discriminator_loss', + (discriminator_real_classification_logits, + discriminator_gen_classification_logits, one_hot_labels)) as scope: + loss_on_generated = losses.softmax_cross_entropy( + one_hot_labels, discriminator_gen_classification_logits, + weights=generated_weights, scope=scope, loss_collection=None, + reduction=reduction) + loss_on_real = losses.softmax_cross_entropy( + one_hot_labels, discriminator_real_classification_logits, + weights=real_weights, label_smoothing=label_smoothing, scope=scope, + loss_collection=None, reduction=reduction) + loss = loss_on_generated + loss_on_real + util.add_loss(loss, loss_collection) - if add_summaries: - summary.scalar('discriminator_gen_ac_loss', loss_on_generated) - summary.scalar('discriminator_real_ac_loss', loss_on_real) - summary.scalar('discriminator_ac_loss', loss) + if add_summaries: + summary.scalar('discriminator_gen_ac_loss', loss_on_generated) + summary.scalar('discriminator_real_ac_loss', loss_on_real) + summary.scalar('discriminator_ac_loss', loss) return loss @@ -275,12 +279,16 @@ def acgan_generator_loss( ValueError: if arg module not either `generator` or `discriminator` TypeError: if the discriminator does not output a tuple. """ - loss = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, weights=weights, - scope=scope, loss_collection=loss_collection, reduction=reduction) + with ops.name_scope( + scope, 'acgan_generator_loss', + (discriminator_gen_classification_logits, one_hot_labels)) as scope: + loss = losses.softmax_cross_entropy( + one_hot_labels, discriminator_gen_classification_logits, + weights=weights, scope=scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('generator_ac_loss', loss) + if add_summaries: + summary.scalar('generator_ac_loss', loss) return loss @@ -289,7 +297,6 @@ def acgan_generator_loss( # GANs` (https://arxiv.org/abs/1704.00028). -# TODO(joelshor): Figure out why this function can't be inside a name scope. def wasserstein_gradient_penalty( real_data, generated_data, @@ -331,48 +338,50 @@ def wasserstein_gradient_penalty( Raises: ValueError: If the rank of data Tensors is unknown. """ - real_data = ops.convert_to_tensor(real_data) - generated_data = ops.convert_to_tensor(generated_data) - if real_data.shape.ndims is None: - raise ValueError('`real_data` can\'t have unknown rank.') - if generated_data.shape.ndims is None: - raise ValueError('`generated_data` can\'t have unknown rank.') - - differences = generated_data - real_data - batch_size = differences.shape[0].value or array_ops.shape(differences)[0] - alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) - alpha = random_ops.random_uniform(shape=alpha_shape) - interpolates = real_data + (alpha * differences) - - # Reuse variables if a discriminator scope already exists. - reuse = False if discriminator_scope is None else True - with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', - reuse=reuse): - disc_interpolates = discriminator_fn(interpolates, generator_inputs) - - if isinstance(disc_interpolates, tuple): - # ACGAN case: disc outputs more than one tensor - disc_interpolates = disc_interpolates[0] - - gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] - gradient_squares = math_ops.reduce_sum( - math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) - # Propagate shape information, if possible. - if isinstance(batch_size, int): - gradient_squares.set_shape([ - batch_size] + gradient_squares.shape.as_list()[1:]) - # For numerical stability, add epsilon to the sum before taking the square - # root. Note tf.norm does not add epsilon. - slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes - 1.0) - penalty = losses.compute_weighted_loss( - penalties, weights, scope=scope, loss_collection=loss_collection, - reduction=reduction) + with ops.name_scope(scope, 'wasserstein_gradient_penalty', + (real_data, generated_data)) as scope: + real_data = ops.convert_to_tensor(real_data) + generated_data = ops.convert_to_tensor(generated_data) + if real_data.shape.ndims is None: + raise ValueError('`real_data` can\'t have unknown rank.') + if generated_data.shape.ndims is None: + raise ValueError('`generated_data` can\'t have unknown rank.') + + differences = generated_data - real_data + batch_size = differences.shape[0].value or array_ops.shape(differences)[0] + alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) + alpha = random_ops.random_uniform(shape=alpha_shape) + interpolates = real_data + (alpha * differences) + + with ops.name_scope(None): # Clear scope so update ops are added properly. + # Reuse variables if variables already exists. + with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', + reuse=variable_scope.AUTO_REUSE): + disc_interpolates = discriminator_fn(interpolates, generator_inputs) + + if isinstance(disc_interpolates, tuple): + # ACGAN case: disc outputs more than one tensor + disc_interpolates = disc_interpolates[0] + + gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] + gradient_squares = math_ops.reduce_sum( + math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) + # Propagate shape information, if possible. + if isinstance(batch_size, int): + gradient_squares.set_shape([ + batch_size] + gradient_squares.shape.as_list()[1:]) + # For numerical stability, add epsilon to the sum before taking the square + # root. Note tf.norm does not add epsilon. + slopes = math_ops.sqrt(gradient_squares + epsilon) + penalties = math_ops.square(slopes - 1.0) + penalty = losses.compute_weighted_loss( + penalties, weights, scope=scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('gradient_penalty_loss', penalty) + if add_summaries: + summary.scalar('gradient_penalty_loss', penalty) - return penalty + return penalty # Original losses from `Generative Adversarial Nets` @@ -546,7 +555,7 @@ def modified_generator_loss( discriminator_gen_outputs, label_smoothing=0.0, weights=1.0, - scope='generator_modified_loss', + scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): @@ -576,12 +585,15 @@ def modified_generator_loss( Returns: A loss Tensor. The shape depends on `reduction`. """ - loss = losses.sigmoid_cross_entropy( - array_ops.ones_like(discriminator_gen_outputs), discriminator_gen_outputs, - weights, label_smoothing, scope, loss_collection, reduction) + with ops.name_scope(scope, 'generator_modified_loss', + [discriminator_gen_outputs]) as scope: + loss = losses.sigmoid_cross_entropy( + array_ops.ones_like(discriminator_gen_outputs), + discriminator_gen_outputs, weights, label_smoothing, scope, + loss_collection, reduction) - if add_summaries: - summary.scalar('generator_modified_loss', loss) + if add_summaries: + summary.scalar('generator_modified_loss', loss) return loss @@ -739,7 +751,7 @@ def mutual_information_penalty( structured_generator_inputs, predicted_distributions, weights=1.0, - scope='generator_modified_loss', + scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): @@ -767,15 +779,16 @@ def mutual_information_penalty( _validate_information_penalty_inputs( structured_generator_inputs, predicted_distributions) - # Calculate the negative log-likelihood of the reconstructed noise. - log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in - zip(predicted_distributions, structured_generator_inputs)] - loss = -1 * losses.compute_weighted_loss( - log_probs, weights, scope, loss_collection=loss_collection, - reduction=reduction) + with ops.name_scope(scope, 'mutual_information_loss') as scope: + # Calculate the negative log-likelihood of the reconstructed noise. + log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in + zip(predicted_distributions, structured_generator_inputs)] + loss = -1 * losses.compute_weighted_loss( + log_probs, weights, scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('mutual_information_penalty', loss) + if add_summaries: + summary.scalar('mutual_information_penalty', loss) return loss diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 3e003dd0f808f80dcc486e78e8e101ac6f198947..b5cd8c92ba180e981e0faf877021cb6d69dc34b4 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -274,8 +274,8 @@ class ACGANLossTest(test.TestCase): self._discriminator_real_classification_logits, 'one_hot_labels': self._one_hot_labels, } - self._generator_loss_name = 'softmax_cross_entropy_loss/value' - self._discriminator_loss_name = 'add' + self._generator_loss_name = 'acgan_generator_loss/value' + self._discriminator_loss_name = 'acgan_discriminator_loss/add' self._expected_g_loss = 3.84974 self._expected_d_loss = 9.43950 @@ -453,10 +453,11 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): 'discriminator_scope': self._scope, } self._expected_loss = 9.00000 - self._expected_op_name = 'weighted_loss/value' + self._expected_op_name = 'wasserstein_gradient_penalty/value' self._batch_size = 1 def _discriminator_fn(self, inputs, _): + ops.add_to_collection('fake_update_ops', constant_op.constant(1.0)) return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs def test_loss_with_placeholder(self): @@ -487,6 +488,26 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): self.assertEqual( num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + def test_works_with_get_collection(self): + """Tests that gradient penalty works inside other scopes.""" + # We ran the discriminator once in the setup, so there should be an op + # already in the collection. + self.assertEqual(1, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + + # Make sure the op is added to the collection even if it's in a name scope. + with ops.name_scope('loss'): + tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) + self.assertEqual(2, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + + # Make sure the op is added to the collection even if it's in a variable + # scope. + with variable_scope.variable_scope('loss_vscope'): + tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) + self.assertEqual(3, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): """Tests for mutual_information_penalty.""" @@ -504,7 +525,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mul' + self._expected_op_name = 'mutual_information_loss/mul' self._batch_size = 2 diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 27512526c4ed11ffac60c0a0db5d3d1c381a8217..48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -120,7 +120,7 @@ class GANLoss( """GANLoss contains the generator and discriminator losses. Args: - generator_loss: A tensor for the generator loss.. + generator_loss: A tensor for the generator loss. discriminator_loss: A tensor for the discriminator loss. """ diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index bebcf079ba444946bf0377106cbafcbaa7e94e74..a8053be69b716c27d48efe7c9ec6c5675b5dd614 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -119,7 +119,6 @@ cc_library( ":gdr_memory_manager", ":gdr_rendezvous_mgr", ":gdr_worker", - "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index b4c53d3da655e2f52b5990ac0de3bc7ccc823bcc..967ad2fc090906e93f22c777816eede37f9a1b04 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -144,12 +144,12 @@ py_test( ":graph_editor_py", ":match", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", + "//tensorflow/python:session", "//tensorflow/python:variables", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 42968ae63b769f7cea7385933fbadb0782cc86f3..7ffdbb7139281734917fdb715601b317eb58b82f 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -397,27 +397,57 @@ def swap_inputs(sgv0, sgv1): def reroute_inputs(sgv0, sgv1): - """Re-route all the inputs of sgv0 to sgv1 (see reroute_inputs).""" + """Re-route all the inputs of two subgraphs. + + Args: + sgv0: the first subgraph to have its inputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + sgv1: the second subgraph to have its inputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + Returns: + A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. + Note that the function argument sgv0 and sgv1 are also modified in place. + Raises: + StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using + the same rules than the function subgraph.make_view. + """ return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b) def swap_outputs(sgv0, sgv1): - """Swap all the outputs of sgv0 and sgv1 (see _reroute_outputs).""" + """Swap all the outputs of sgv0 and sgv1 (see reroute_outputs).""" return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap) def reroute_outputs(sgv0, sgv1): - """Re-route all the outputs of sgv0 to sgv1 (see _reroute_outputs).""" + """Re-route all the outputs of two operations. + + Args: + sgv0: the first subgraph to have its outputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + sgv1: the second subgraph to have its outputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + Returns: + A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. + Note that the function argument sgv0 and sgv1 are also modified in place. + Raises: + StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using + the same rules than the function subgraph.make_view. + """ return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b) def swap_ios(sgv0, sgv1): - """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute).""" + """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute_sgv).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap) def reroute_ios(sgv0, sgv1): - """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute).""" + """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b) diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ab5776b9dd66bb082e9ca3922e8902bfebe6b0b8..ca00394388f67e2ed9508684a47b23c3ee9e79e8 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -191,14 +191,14 @@ class TransformTest(test.TestCase): # Extract the operations. replacement_ts = {w.value(): g} original_mul1_grad = (ops.get_default_graph(). - get_operation_by_name("grad/mul1_grad/mul_1")) + get_operation_by_name("grad/mul1_grad/Mul_1")) # Should not raise exception. res = ge.graph_replace(g, replacement_ts, dst_scope="res") # Extract the operations after graph_replace. result_mul1_grad = (ops.get_default_graph(). - get_operation_by_name("res/grad/mul1_grad/mul_1")) + get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. self.assertEquals(original_mul1_grad._original_op.name, u"mul1") diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD index d81e868d4a922698e4755733b999112088fa2a0b..1b528d7afc1112f5dc0667ae299ade02bc8fd04b 100644 --- a/tensorflow/contrib/hooks/BUILD +++ b/tensorflow/contrib/hooks/BUILD @@ -19,30 +19,11 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", - "//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python:util", ], ) -py_test( - name = "profiler_hook_test", - size = "small", - srcs = ["python/training/profiler_hook_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":hooks", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/hooks/python/training/profiler_hook.py b/tensorflow/contrib/hooks/python/training/profiler_hook.py index 35aa25edfde6f2ed7051ed75ff4f53f8732ae76e..6173aa0797138730e79b21bc9a1779d346edab6b 100644 --- a/tensorflow/contrib/hooks/python/training/profiler_hook.py +++ b/tensorflow/contrib/hooks/python/training/profiler_hook.py @@ -12,93 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Additional `SessionRunHook` implementations to complement those in -tensorflow/python/training. - -""" +"""Placeholder of ProfilerHook for backward compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path - -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import timeline -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer -from tensorflow.python.training.session_run_hook import SessionRunArgs -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training_util - - -class ProfilerHook(session_run_hook.SessionRunHook): - """Captures CPU/GPU profiling information every N steps or seconds. - - This produces files called "timeline-.json", which are in Chrome - Trace format. - - For more information see: - https://github.com/catapult-project/catapult/blob/master/tracing/README.md""" - - def __init__(self, - save_steps=None, - save_secs=None, - output_dir="", - show_dataflow=True, - show_memory=False): - """Initializes a hook that takes periodic profiling snapshots. - - Args: - save_steps: `int`, save profile traces every N steps. Exactly one of - `save_secs` and `save_steps` should be set. - save_secs: `int`, save profile traces every N seconds. - output_dir: `string`, the directory to save the profile traces to. - Defaults to the current directory. - show_dataflow: `bool`, if True, add flow events to the trace connecting - producers and consumers of tensors. - show_memory: `bool`, if True, add object snapshot events to the trace - showing the sizes and lifetimes of tensors. - """ - self._output_file = os.path.join(output_dir, "timeline-{}.json") - self._show_dataflow = show_dataflow - self._show_memory = show_memory - self._timer = SecondOrStepTimer(every_secs=save_secs, - every_steps=save_steps) - - def begin(self): - self._next_step = None - self._global_step_tensor = training_util.get_global_step() - if self._global_step_tensor is None: - raise RuntimeError( - "Global step should be created to use ProfilerHook.") - - def before_run(self, run_context): - self._request_summary = ( - self._next_step is None or - self._timer.should_trigger_for_step(self._next_step)) - requests = {"global_step": self._global_step_tensor} - opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) - if self._request_summary else None) - - return SessionRunArgs(requests, options=opts) - - def after_run(self, run_context, run_values): - global_step = run_values.results["global_step"] - - if self._request_summary: - self._timer.update_last_triggered_step(global_step) - self._save(global_step, - self._output_file.format(global_step), - run_values.run_metadata.step_stats) - - self._next_step = global_step + 1 +from tensorflow.python.training import basic_session_run_hooks - def _save(self, step, save_path, step_stats): - logging.info("Saving timeline for %d into '%s'.", step, save_path) - with gfile.Open(save_path, "w") as f: - trace = timeline.Timeline(step_stats) - f.write(trace.generate_chrome_trace_format( - show_dataflow=self._show_dataflow, - show_memory=self._show_memory)) +ProfilerHook = basic_session_run_hooks.ProfilerHook # pylint: disable=invalid-name diff --git a/tensorflow/contrib/hooks/python/training/profiler_hook_test.py b/tensorflow/contrib/hooks/python/training/profiler_hook_test.py deleted file mode 100644 index e7ecb5eb2fcc56f14f3d5babe2c22652159afd76..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/hooks/python/training/profiler_hook_test.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for profiler_hook.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os.path -import shutil -import tempfile - -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.hooks.python.training import ProfilerHook -from tensorflow.python.framework import ops -from tensorflow.python.ops import state_ops -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.training import monitored_session - - -class ProfilerHookTest(test.TestCase): - - def setUp(self): - super(ProfilerHookTest, self).setUp() - self.output_dir = tempfile.mkdtemp() - self.graph = ops.Graph() - self.filepattern = os.path.join(self.output_dir, "timeline-*.json") - with self.graph.as_default(): - self.global_step = variables.get_or_create_global_step() - self.train_op = state_ops.assign_add(self.global_step, 1) - - def tearDown(self): - super(ProfilerHookTest, self).tearDown() - shutil.rmtree(self.output_dir, ignore_errors=True) - - def _count_timeline_files(self): - return len(gfile.Glob(self.filepattern)) - - def test_raise_in_both_secs_and_steps(self): - with self.assertRaises(ValueError): - ProfilerHook(save_secs=10, save_steps=20) - - def test_raise_in_none_secs_and_steps(self): - with self.assertRaises(ValueError): - ProfilerHook(save_secs=None, save_steps=None) - - def test_save_secs_saves_in_first_step(self): - with self.graph.as_default(): - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) - self.assertEqual(1, self._count_timeline_files()) - - @test.mock.patch('time.time') - def test_save_secs_saves_periodically(self, mock_time): - # Pick a fixed start time. - current_time = 1484863632.320497 - - with self.graph.as_default(): - mock_time.return_value = current_time - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - # Simulate 2.5 seconds of sleep. - mock_time.return_value = current_time + 2.5 - sess.run(self.train_op) # Saved. - - # Pretend some small amount of time has passed. - mock_time.return_value = current_time + 0.1 - sess.run(self.train_op) # Not saved. - # Edge test just before we should save the timeline. - mock_time.return_value = current_time + 1.9 - sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) - - mock_time.return_value = current_time + 4.5 - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - - def test_save_steps_saves_in_first_step(self): - with self.graph.as_default(): - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - - def test_save_steps_saves_periodically(self): - with self.graph.as_default(): - hook = ProfilerHook(save_steps=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - self.assertEqual(0, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index a18f14112e469b1cf83a046fa65b87e5c69fb88b..c0c56d2e4aca5eea4bc9b59d942b1b766ace0304 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -143,12 +143,13 @@ py_library( srcs_version = "PY2AND3", deps = [ ":distort_image_ops", + ":single_image_random_dot_stereograms_py", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:image_ops", "//tensorflow/python:platform", "//tensorflow/python:random_ops", + "//tensorflow/python:util", ], ) @@ -211,6 +212,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":image_py", + ":single_image_random_dot_stereograms_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index 59a322d3ca6e7e53872f8e7e126e30923ddd77a0..d030dffadeb9d67f7ffcbc197a2a3feb9b3b122d 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -26,6 +26,8 @@ projective transforms (including rotation) are supported. @@random_yiq_hsv @@rotate @@transform +@@translate +@@translations_to_projective_transforms @@bipartite_match @@single_image_random_dot_stereograms """ @@ -41,6 +43,8 @@ from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_t from tensorflow.contrib.image.python.ops.image_ops import compose_transforms from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform +from tensorflow.contrib.image.python.ops.image_ops import translate +from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index b8a0706b61449ebebeb2f1dc98b438f9dd620aa3..b50177ae5651fbc15f292e11031411c2074357ec 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -36,8 +36,8 @@ _DTYPES = set( class ImageOpsTest(test_util.TensorFlowTestCase): def test_zeros(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): for shape in [(5, 5), (24, 24), (2, 24, 24, 3)]: for angle in [0, 1, np.pi / 2.0]: image = array_ops.zeros(shape, dtype) @@ -46,8 +46,8 @@ class ImageOpsTest(test_util.TensorFlowTestCase): np.zeros(shape, dtype.as_numpy_dtype())) def test_rotate_even(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = array_ops.reshape( math_ops.cast(math_ops.range(36), dtype), (6, 6)) image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1]) @@ -68,8 +68,8 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [1, 7, 13, 19, 25, 31], [0, 6, 12, 18, 24, 30]]]) def test_rotate_odd(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = array_ops.reshape( math_ops.cast(math_ops.range(25), dtype), (5, 5)) image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1]) @@ -87,9 +87,25 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [22, 17, 12, 7, 2], [23, 18, 13, 8, 3], [24, 19, 14, 9, 4]]]) + def test_translate(self): + for dtype in _DTYPES: + with self.test_session(): + image = constant_op.constant( + [[1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1]], dtype=dtype) + translation = constant_op.constant([-1, -1], dtypes.float32) + image_translated = image_ops.translate(image, translation) + self.assertAllEqual(image_translated.eval(), + [[1, 0, 1, 0], + [0, 1, 0, 0], + [1, 0, 1, 0], + [0, 0, 0, 0]]) + def test_compose(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = constant_op.constant( [[1, 1, 1, 0], [1, 0, 0, 0], @@ -246,4 +262,3 @@ class BipartiteMatchTest(test_util.TensorFlowTestCase): if __name__ == "__main__": googletest.main() - diff --git a/tensorflow/contrib/image/python/ops/distort_image_ops.py b/tensorflow/contrib/image/python/ops/distort_image_ops.py index 39f023a2b40a1a8481217fe8fa191a5072e7a3ff..06e8e4ee720d04f4b29a25f833297bb17a7d239c 100644 --- a/tensorflow/contrib/image/python/ops/distort_image_ops.py +++ b/tensorflow/contrib/image/python/ops/distort_image_ops.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.ops import gen_distort_image_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -132,7 +133,7 @@ def adjust_hsv_in_yiq(image, orig_dtype = image.dtype flt_image = image_ops.convert_image_dtype(image, dtypes.float32) - rgb_altered = _distort_image_ops.adjust_hsv_in_yiq( + rgb_altered = gen_distort_image_ops.adjust_hsv_in_yiq( flt_image, delta_hue, scale_saturation, scale_value) return image_ops.convert_image_dtype(rgb_altered, orig_dtype) diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index aef3e385b57486d5cb3cb13d9e8b9519768abd7c..011ddeaa9a1eebaa507c9e0d33f9546ff3497166 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -37,16 +37,18 @@ _IMAGE_DTYPES = set( ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) -def rotate(images, angles, interpolation="NEAREST"): +def rotate(images, angles, interpolation="NEAREST", name=None): """Rotate image(s) by the passed angle(s) in radians. Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. angles: A scalar angle to rotate all images by, or (if images has rank 4) a vector of length num_images, with an angle for each image in the batch. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + name: The name of the op. Returns: Image(s) with the same type and shape as `images`, rotated by the given @@ -55,38 +57,77 @@ def rotate(images, angles, interpolation="NEAREST"): Raises: TypeError: If `image` is an invalid type. """ - image_or_images = ops.convert_to_tensor(images, name="images") - if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") - - image_height = math_ops.cast(array_ops.shape(images)[1], dtypes.float32)[None] - image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None] - output = transform( - images, - angles_to_projective_transforms(angles, image_height, image_width), - interpolation=interpolation) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + with ops.name_scope(name, "rotate"): + image_or_images = ops.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + elif image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :, None] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images[None, :, :, :] + elif len(image_or_images.get_shape()) == 4: + images = image_or_images + else: + raise TypeError("Images should have rank between 2 and 4.") + + image_height = math_ops.cast(array_ops.shape(images)[1], + dtypes.float32)[None] + image_width = math_ops.cast(array_ops.shape(images)[2], + dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation) + if image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + return output[0, :, :, 0] + elif len(image_or_images.get_shape()) == 3: + return output[0, :, :, :] + else: + return output + + +def translate(images, translations, interpolation="NEAREST", name=None): + """Translate image(s) by the passed vectors(s). + Args: + images: A tensor of shape (num_images, num_rows, num_columns, num_channels) + (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. + translations: A vector representing [dx, dy] or (if images has rank 4) + a matrix of length num_images, with a [dx, dy] vector for each image in + the batch. + interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + name: The name of the op. -def angles_to_projective_transforms(angles, image_height, image_width): + Returns: + Image(s) with the same type and shape as `images`, translated by the given + vector(s). Empty space due to the translation will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + """ + with ops.name_scope(name, "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation) + + +def angles_to_projective_transforms(angles, + image_height, + image_width, + name=None): """Returns projective transform(s) for the given angle(s). Args: angles: A scalar angle to rotate all images by, or (for batches of images) - a vector with an angle to rotate each image in the batch. + a vector with an angle to rotate each image in the batch. The rank must + be statically known (the shape is not `TensorShape(None)`. image_height: Height of the image(s) to be transformed. image_width: Width of the image(s) to be transformed. @@ -94,41 +135,89 @@ def angles_to_projective_transforms(angles, image_height, image_width): A tensor of shape (num_images, 8). Projective transforms which can be given to `tf.contrib.image.transform`. """ - angle_or_angles = ops.convert_to_tensor( - angles, name="angles", dtype=dtypes.float32) - if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test - angles = angle_or_angles[None] - elif len(angle_or_angles.get_shape()) == 1: - angles = angle_or_angles - else: - raise TypeError("Angles should have rank 0 or 1.") - x_offset = ((image_width - 1) - (math_ops.cos(angles) * - (image_width - 1) - math_ops.sin(angles) * - (image_height - 1))) / 2.0 - y_offset = ((image_height - 1) - (math_ops.sin(angles) * - (image_width - 1) + math_ops.cos(angles) * - (image_height - 1))) / 2.0 - num_angles = array_ops.shape(angles)[0] - return array_ops.concat( - values=[ - math_ops.cos(angles)[:, None], - -math_ops.sin(angles)[:, None], - x_offset[:, None], - math_ops.sin(angles)[:, None], - math_ops.cos(angles)[:, None], - y_offset[:, None], - array_ops.zeros((num_angles, 2), dtypes.float32), - ], - axis=1) - - -def transform(images, transforms, interpolation="NEAREST"): + with ops.name_scope(name, "angles_to_projective_transforms"): + angle_or_angles = ops.convert_to_tensor( + angles, name="angles", dtype=dtypes.float32) + if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test + angles = angle_or_angles[None] + elif len(angle_or_angles.get_shape()) == 1: + angles = angle_or_angles + else: + raise TypeError("Angles should have rank 0 or 1.") + x_offset = ((image_width - 1) - (math_ops.cos(angles) * + (image_width - 1) - math_ops.sin(angles) * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - (math_ops.sin(angles) * + (image_width - 1) + math_ops.cos(angles) * + (image_height - 1))) / 2.0 + num_angles = array_ops.shape(angles)[0] + return array_ops.concat( + values=[ + math_ops.cos(angles)[:, None], + -math_ops.sin(angles)[:, None], + x_offset[:, None], + math_ops.sin(angles)[:, None], + math_ops.cos(angles)[:, None], + y_offset[:, None], + array_ops.zeros((num_angles, 2), dtypes.float32), + ], + axis=1) + + +def translations_to_projective_transforms(translations, name=None): + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing [dx, dy] or a matrix of + 2-element lists representing [dx, dy] to translate for each image + (for a batch of images). The rank must be statically known (the shape + is not `TensorShape(None)`. + name: The name of the op. + + Returns: + A tensor of shape (num_images, 8) projective transforms which can be given + to `tf.contrib.image.transform`. + """ + with ops.name_scope(name, "translations_to_projective_transforms"): + translation_or_translations = ops.convert_to_tensor( + translations, name="translations", dtype=dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + elif len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + elif len(translation_or_translations.get_shape()) == 2: + translations = translation_or_translations + else: + raise TypeError("Translations should have rank 1 or 2.") + num_translations = array_ops.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return array_ops.concat( + values=[ + array_ops.ones((num_translations, 1), dtypes.float32), + array_ops.zeros((num_translations, 1), dtypes.float32), + -translations[:, 0, None], + array_ops.zeros((num_translations, 1), dtypes.float32), + array_ops.ones((num_translations, 1), dtypes.float32), + -translations[:, 1, None], + array_ops.zeros((num_translations, 2), dtypes.float32), + ], + axis=1) + + +def transform(images, transforms, interpolation="NEAREST", name=None): """Applies the given transform(s) to the image(s). Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. transforms: Projective transform matrix/matrices. A vector of length 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point @@ -146,34 +235,40 @@ def transform(images, transforms, interpolation="NEAREST"): Raises: TypeError: If `image` is an invalid type. """ - image_or_images = ops.convert_to_tensor(images, name="images") - transform_or_transforms = ops.convert_to_tensor( - transforms, name="transforms", dtype=dtypes.float32) - if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") - - if len(transform_or_transforms.get_shape()) == 1: - transforms = transform_or_transforms[None] - elif len(transform_or_transforms.get_shape()) == 2: - transforms = transform_or_transforms - else: - raise TypeError("Transforms should have rank 1 or 2.") - output = gen_image_ops.image_projective_transform( - images, transforms, interpolation=interpolation.upper()) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + with ops.name_scope(name, "transform"): + image_or_images = ops.convert_to_tensor(images, name="images") + transform_or_transforms = ops.convert_to_tensor( + transforms, name="transforms", dtype=dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + elif image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :, None] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images[None, :, :, :] + elif len(image_or_images.get_shape()) == 4: + images = image_or_images + else: + raise TypeError("Images should have rank between 2 and 4.") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise TypeError( + "transform_or_transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + raise TypeError("Transforms should have rank 1 or 2.") + output = gen_image_ops.image_projective_transform( + images, transforms, interpolation=interpolation.upper()) + if len(image_or_images.get_shape()) == 2: + return output[0, :, :, 0] + elif len(image_or_images.get_shape()) == 3: + return output[0, :, :, :] + else: + return output def compose_transforms(*transforms): @@ -191,11 +286,12 @@ def compose_transforms(*transforms): order. """ assert transforms, "transforms cannot be empty" - composed = _flat_transforms_to_matrices(transforms[0]) - for tr in transforms[1:]: - # Multiply batches of matrices. - composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) - return _transform_matrices_to_flat(composed) + with ops.name_scope("compose_transforms"): + composed = _flat_transforms_to_matrices(transforms[0]) + for tr in transforms[1:]: + # Multiply batches of matrices. + composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) + return _transform_matrices_to_flat(composed) def _flat_transforms_to_matrices(transforms): @@ -211,8 +307,8 @@ def _flat_transforms_to_matrices(transforms): def _transform_matrices_to_flat(transform_matrices): # Flatten each matrix. - transforms = array_ops.reshape( - transform_matrices, constant_op.constant([-1, 9])) + transforms = array_ops.reshape(transform_matrices, + constant_op.constant([-1, 9])) # Divide each matrix by the last entry (normally 1). transforms /= transforms[:, 8:9] return transforms[:, :8] @@ -260,10 +356,10 @@ def _image_projective_transform_grad(op, grad): return [output, None] -def bipartite_match( - distance_mat, - num_valid_rows, - top_k=-1): +def bipartite_match(distance_mat, + num_valid_rows, + top_k=-1, + name="bipartite_match"): """Find bipartite matching based on a given distance matrix. A greedy bi-partite matching algorithm is used to obtain the matching with @@ -282,6 +378,7 @@ def bipartite_match( top_k: A scalar that specifies the number of top-k matches to retrieve. If set to be negative, then is set according to the maximum number of matches from `distance_mat`. + name: The name of the op. Returns: row_to_col_match_indices: A vector of length num_rows, which is the number @@ -292,7 +389,8 @@ def bipartite_match( If `col_to_row_match_indices[j]` is not -1, column j is matched to row `col_to_row_match_indices[j]`. """ - result = gen_image_ops.bipartite_match(distance_mat, num_valid_rows, top_k) + result = gen_image_ops.bipartite_match( + distance_mat, num_valid_rows, top_k, name=name) return result diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 79261c5e7501566537ee9492b5aa64570599e862..5cccf26028ca6bf269dbc67a33075351edecb407 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.ops import gen_single_image_random_dot_stereograms_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader @@ -107,7 +108,7 @@ def single_image_random_dot_stereograms( 'depth_values' """ - result = _sirds_ops.single_image_random_dot_stereograms( + result = gen_single_image_random_dot_stereograms_ops.single_image_random_dot_stereograms( # pylint: disable=line-too-long depth_values=depth_values, hidden_surface_removal=hidden_surface_removal, convergence_dots_size=convergence_dots_size, diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index ae1402b0e6688a0f43278999d1d93282ea2a11a5..a2f320ab11291e4049c8367e1f133a4fbcb72a62 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -64,6 +64,7 @@ py_test( name = "kernel_estimators_test", srcs = ["python/kernel_estimators_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":kernel_methods", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 1b2a5cdd3871f8a7848ee5a8df70452e58cc84a2..0653e71d1244387c6b110462047631b58124d253 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -13,6 +13,8 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:fisher_estimator", "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -86,7 +88,6 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/contrib/kfac/python/ops:loss_functions", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -127,6 +128,20 @@ py_test( ], ) +py_test( + name = "loss_functions_test", + srcs = ["loss_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:loss_functions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index 281274d88473ecd32bd18813a8a7e6a09d2dcc77..b52a7b52a7efd4292ad514c5a744c4da07082142 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -20,42 +20,80 @@ from __future__ import print_function from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] + class EstimatorTest(test.TestCase): - def testEstimatorInitManualRegistration(self): - with ops.Graph().as_default(): - layer_collection = lc.LayerCollection() + def setUp(self): + self._graph = ops.Graph() + with self._graph.as_default(): + self.layer_collection = lc.LayerCollection() - inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32) - weights = variable_scope.get_variable( - 'w', shape=(2, 2), dtype=dtypes.float32) - bias = variable_scope.get_variable( - 'b', initializer=init_ops.zeros_initializer(), shape=(2, 1)) - output = math_ops.matmul(inputs, weights) + bias + self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32) + self.weights = variable_scope.get_variable( + "w", shape=(2, 2), dtype=dtypes.float32) + self.bias = variable_scope.get_variable( + "b", initializer=init_ops.zeros_initializer(), shape=(2, 1)) + self.output = math_ops.matmul(self.inputs, self.weights) + self.bias # Only register the weights. - layer_collection.register_fully_connected((weights,), inputs, output) + self.layer_collection.register_fully_connected( + params=(self.weights,), inputs=self.inputs, outputs=self.output) - outputs = math_ops.tanh(output) - layer_collection.register_categorical_predictive_distribution(outputs) + self.outputs = math_ops.tanh(self.output) + self.targets = array_ops.zeros_like(self.outputs) + self.layer_collection.register_categorical_predictive_distribution( + logits=self.outputs, targets=self.targets) + def testEstimatorInitManualRegistration(self): + with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimator([weights], 0.1, 0.2, layer_collection) + estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - estimator.FisherEstimator([weights, bias], 0.1, 0.2, layer_collection) + estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, + self.layer_collection) + + # Check that we throw an error if we don't include registered variables, + # i.e. self.weights + with self.assertRaises(ValueError): + estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + + @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) + def testVariableWrongNumberOfUses(self, mock_uses): + with self.assertRaises(ValueError): + estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + + def testInvalidEstimationMode(self): + with self.assertRaises(ValueError): + estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, + "not_a_real_mode") + + def testModeListCorrect(self): + with self._graph.as_default(): + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection) + self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys()) + + def testAllModesBuild(self): + for mode in _ALL_ESTIMATION_MODES: + with self._graph.as_default(): + estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection, mode) -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index f48d1980babe283d5bb6e911bdabc469481a74fb..85ac08a1eb7df7da8b4ceedd18929f6c3293c055 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -209,6 +209,146 @@ class NaiveDiagonalFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +class FullyConnectedDiagonalFB(test.TestCase): + + def setUp(self): + super(FullyConnectedDiagonalFB, self).setUp() + + self.batch_size = 4 + self.input_size = 6 + self.output_size = 3 + + self.inputs = np.random.randn(self.batch_size, self.input_size).astype( + np.float32) + self.outputs = np.zeros([self.batch_size, self.output_size]).astype( + np.float32) + self.output_grads = np.random.randn(self.batch_size, + self.output_size).astype(np.float32) + self.w = np.random.randn(self.input_size, self.output_size).astype( + np.float32) + self.b = np.random.randn(self.output_size).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, np.ones([self.batch_size, 1])], axis=1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads) + + def buildDiagonalFisherApproximation(self, inputs, output_grads): + """Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[outer(input, output_grad)] + + where the expectation is taken over examples. + + Args: + inputs: np.array of shape [batch_size, input_size]. + output_grads: np.array of shape [batch_size, output_size]. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + input_size * output_size. + """ + batch_size = inputs.shape[0] + assert output_grads.shape[0] == batch_size + input_size = inputs.shape[1] + output_size = output_grads.shape[1] + fisher_diag = np.zeros((input_size, output_size)) + for i in range(batch_size): + fisher_diag += np.square(np.outer(inputs[i], output_grads[i])) + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testRegisterAdditionalMinibatch(self): + """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, + np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + expected_result = self.fisherApprox(True).dot( + np.concatenate([self.w.flatten(), self.b.flatten()])) + expected_result = expected_result.reshape( + [self.input_size + 1, self.output_size]) + expected_result = (expected_result[:-1], expected_result[-1]) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. + + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.FullyConnectedDiagonalFB( + lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) + for (i, o) in zip(inputs, outputs): + block.register_additional_minibatch(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + class FullyConnectedKFACBasicFBTest(test.TestCase): def testFullyConnectedKFACBasicFBInit(self): @@ -216,50 +356,51 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): random_seed.set_random_seed(200) inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), inputs, - outputs) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) + block.register_additional_minibatch(inputs, outputs) - self.assertAllEqual(outputs, block.tensors_to_compute_grads()) + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) def testInstantiateFactorsHasBias(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB( - lc.LayerCollection(), inputs, outputs, has_bias=True) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) + block.register_additional_minibatch(inputs, outputs) grads = outputs**2 - block.instantiate_factors((grads,), 0.5) + block.instantiate_factors(([grads],), 0.5) def testInstantiateFactorsNoBias(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB( - lc.LayerCollection(), inputs, outputs, has_bias=False) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) grads = outputs**2 - block.instantiate_factors((grads,), 0.5) + block.instantiate_factors(([grads],), 0.5) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB( - lc.LayerCollection(), inputs, outputs, has_bias=False) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) grads = outputs**2 - block.instantiate_factors((grads,), 0.5) + block.instantiate_factors(([grads],), 0.5) # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) - vector = (np.arange(2, 6).reshape(2, 2).astype(np.float32), np.arange( - 1, 3).reshape(2, 1).astype(np.float32)) + vector = ( + np.arange(2, 6).reshape(2, 2).astype(np.float32), # + np.arange(1, 3).reshape(2, 1).astype(np.float32)) output = block.multiply_inverse((array_ops.constant(vector[0]), array_ops.constant(vector[1]))) @@ -273,10 +414,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): random_seed.set_random_seed(200) inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB( - lc.LayerCollection(), inputs, outputs, has_bias=False) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) grads = outputs**2 - block.instantiate_factors((grads,), 0.5) + block.instantiate_factors(([grads],), 0.5) # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -296,11 +437,11 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.zeros([32, input_dim]) outputs = array_ops.zeros([32, output_dim]) params = array_ops.zeros([input_dim, output_dim]) - block = fb.FullyConnectedKFACBasicFB( - lc.LayerCollection(), inputs, outputs, has_bias=False) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. - block.instantiate_factors((grads,), damping) + block.instantiate_factors(([grads],), damping) sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) @@ -318,6 +459,188 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +class ConvDiagonalFBTest(test.TestCase): + + def setUp(self): + super(ConvDiagonalFBTest, self).setUp() + + self.batch_size = 2 + self.height = 8 + self.width = 4 + self.input_channels = 6 + self.output_channels = 3 + self.kernel_size = 1 + + self.inputs = np.random.randn(self.batch_size, self.height, self.width, + self.input_channels).astype(np.float32) + self.outputs = np.zeros( + [self.batch_size, self.height, self.width, + self.output_channels]).astype(np.float32) + self.output_grads = np.random.randn( + self.batch_size, self.height, self.width, self.output_channels).astype( + np.float32) + self.w = np.random.randn(self.kernel_size, self.kernel_size, + self.input_channels, self.output_channels).astype( + np.float32) + self.b = np.random.randn(self.output_channels).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, + np.ones([self.batch_size, self.height, self.width, 1])], + axis=-1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads, + self.kernel_size) + + def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size): + r"""Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})] + + where the expectation is taken over examples and the sum over (x, y) + locations upon which the convolution is applied. + + Args: + inputs: np.array of shape [batch_size, height, width, input_channels]. + output_grads: np.array of shape [batch_size, height, width, + output_channels]. + kernel_size: int. height and width of kernel. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + kernel_size^2 * input_channels * output_channels. + """ + batch_size, height, width, input_channels = inputs.shape + assert output_grads.shape[0] == batch_size + assert output_grads.shape[1] == height + assert output_grads.shape[2] == width + output_channels = output_grads.shape[3] + + # If kernel_size == 1, then we don't need to worry about capturing context + # around the pixel upon which a convolution is applied. This makes testing + # easier. + assert kernel_size == 1, "kernel_size != 1 isn't supported." + num_locations = height * width + inputs = np.reshape(inputs, [batch_size, num_locations, input_channels]) + output_grads = np.reshape(output_grads, + [batch_size, num_locations, output_channels]) + + fisher_diag = np.zeros((input_channels, output_channels)) + for i in range(batch_size): + # Each example's approximation is a square(sum-of-outer-products). + example_fisher_diag = np.zeros((input_channels, output_channels)) + for j in range(num_locations): + example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j]) + fisher_diag += np.square(example_fisher_diag) + + # Normalize by batch_size (not num_locations). + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result, atol=1e-3) + + def testRegisterAdditionalMinibatch(self): + """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, + np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + # Clone 'b' along 'input_channels' dimension. + b_filter = np.tile( + np.reshape(self.b, [1, 1, 1, self.output_channels]), + [self.kernel_size, self.kernel_size, 1, 1]) + params = np.concatenate([self.w, b_filter], axis=2) + expected_result = self.fisherApprox(True).dot(params.flatten()) + + # Extract 'b' from concatenated parameters. + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels + 1, + self.output_channels + ]) + expected_result = (expected_result[:, :, 0:-1, :], np.reshape( + expected_result[:, :, -1, :], [self.output_channels])) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. + + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.ConvDiagonalFB( + lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') + for (i, o) in zip(inputs, outputs): + block.register_additional_minibatch(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + class ConvKFCBasicFBTest(test.TestCase): def _testConvKFCBasicFBInitParams(self, params): @@ -437,5 +760,11 @@ class ConvKFCBasicFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +def as_tensors(tensor_or_tuple): + """Converts a potentially nested tuple of np.array to Tensors.""" + if isinstance(tensor_or_tuple, (tuple, list)): + return tuple(as_tensors(t) for t in tensor_or_tuple) + return ops.convert_to_tensor(tensor_or_tuple) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index 633104ace01dda6a6ba1ba058486ba39f18326e7..4f27ceced9836e415aff38460f3e5940cdf1414f 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -30,6 +30,43 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +class LayerParametersDictTest(test.TestCase): + + def testSetItem(self): + """Ensure insertion, contains, retrieval works for supported key types.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y0 = array_ops.constant(0) + y1 = array_ops.constant(0) + z0 = array_ops.constant(0) + z1 = array_ops.constant(0) + keys = [x, (y0, y1), [z0, z1]] + for key in keys: + lp_dict[key] = key + + for key in keys: + self.assertTrue(key in lp_dict) + self.assertEqual(lp_dict[key], key) + + def testSetItemOverlap(self): + """Ensure insertion fails if key overlaps with existing key.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y = array_ops.constant(0) + lp_dict[x] = 'value' + + with self.assertRaises(ValueError): + lp_dict[(x, y)] = 'value' + + # Ensure 'y' wasn't inserted. + self.assertTrue(x in lp_dict) + self.assertFalse(y in lp_dict) + + class LayerCollectionTest(test.TestCase): def testLayerCollectionInit(self): @@ -44,9 +81,18 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.register_fully_connected( array_ops.constant(1), array_ops.constant(2), array_ops.constant(3)) + lc.register_fully_connected( + array_ops.constant(1), + array_ops.constant(2), + array_ops.constant(3), + approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_conv2d( array_ops.constant(4), [1, 1, 1, 1], 'SAME', array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) + lc.register_conv2d( + array_ops.constant(4), [1, 1, 1, 1], 'SAME', + array_ops.ones((1, 1, 1, 1)), array_ops.constant(3), + approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) lc.register_generic( @@ -54,7 +100,7 @@ class LayerCollectionTest(test.TestCase): 16, approx=layer_collection.APPROX_DIAGONAL_NAME) - self.assertEqual(4, len(lc.get_blocks())) + self.assertEqual(6, len(lc.get_blocks())) def testRegisterBlocksMultipleRegistrations(self): with ops.Graph().as_default(): @@ -157,6 +203,83 @@ class LayerCollectionTest(test.TestCase): double_loss = sess.run(lc2.total_sampled_loss()) self.assertAlmostEqual(2 * single_loss, double_loss) + def testLossFunctionByName(self): + """Ensure loss functions can be identified by name.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function by name. + lc.register_categorical_predictive_distribution(logits, name='loss1') + self.assertEqual(1, len(lc.losses)) + + # Add logits to same loss function. + lc.register_categorical_predictive_distribution( + logits, name='loss1', reuse=True) + self.assertEqual(1, len(lc.losses)) + + # Add another new loss function. + lc.register_categorical_predictive_distribution(logits, name='loss2') + self.assertEqual(2, len(lc.losses)) + + def testLossFunctionWithoutName(self): + """Ensure loss functions get unique names if 'name' not specified.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function with default names. + lc.register_categorical_predictive_distribution(logits) + lc.register_categorical_predictive_distribution(logits) + self.assertEqual(2, len(lc.losses)) + + def testCategoricalPredictiveDistributionMultipleMinibatches(self): + """Ensure multiple minibatches are registered.""" + with ops.Graph().as_default(): + batch_size = 3 + output_size = 2 + logits = array_ops.zeros([batch_size, output_size]) + targets = array_ops.ones([batch_size], dtype=dtypes.int32) + lc = layer_collection.LayerCollection() + + # Create a new loss function. + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1') + + # Can add when reuse=True + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=True) + + # Can add when reuse=VARIABLE_SCOPE and reuse=True there. + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + # Can't add when reuse=False + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=False) + + # Can't add when reuse=VARIABLE_SCOPE and reuse=False there. + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + self.assertEqual(len(lc.losses), 1) + loss = lc.losses[0] + + # Three successful registrations. + self.assertEqual(loss.params.shape.as_list(), + [3 * batch_size, output_size]) + self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size]) + def testRegisterCategoricalPredictiveDistributionBatchSize1(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -206,6 +329,73 @@ class LayerCollectionTest(test.TestCase): single_loss = sess.run(lc.total_loss()) self.assertAlmostEqual(7.6983433, single_loss) + def testRegisterFullyConnectedReuse(self): + """Ensure the 'reuse' keyword argument function as intended.""" + with ops.Graph().as_default(): + inputs = [ + array_ops.ones([2, 10]), # + array_ops.zeros([5, 10]) + ] + outputs = [ + array_ops.zeros([2, 5]), # + array_ops.ones([5, 5]) + ] + params = ( + variable_scope.get_variable('w', [10, 5]), # + variable_scope.get_variable('b', [5])) + + # Fails on second if reuse=False. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + with self.assertRaises(ValueError): + lc.register_fully_connected(params, inputs[1], outputs[1], reuse=False) + + # Succeeds on second if reuse=True. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + lc.register_fully_connected(params, inputs[1], outputs[1], reuse=True) + + # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + with self.assertRaises(ValueError): + lc.register_fully_connected( + params, + inputs[1], + outputs[1], + reuse=layer_collection.VARIABLE_SCOPE) + + # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + lc.register_fully_connected( + params, + inputs[1], + outputs[1], + reuse=layer_collection.VARIABLE_SCOPE) + + # Fails if block type changes. + lc = layer_collection.LayerCollection() + lc.register_fully_connected( + params, + inputs[0], + outputs[0], + approx=layer_collection.APPROX_KRONECKER_NAME) + with self.assertRaises(ValueError): + lc.register_fully_connected( + params, + inputs[1], + outputs[1], + approx=layer_collection.APPROX_DIAGONAL_NAME, + reuse=True) + + # Fails if reuse requested but no FisherBlock exists. + lc = layer_collection.LayerCollection() + with self.assertRaises(KeyError): + lc.register_fully_connected(params, inputs[0], outputs[0], reuse=True) + def testMakeOrGetFactor(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -237,10 +427,20 @@ class LayerCollectionTest(test.TestCase): self.assertTrue(all([var.name.startswith(scope) for var in variables])) def testGetUseCountMap(self): + """Ensure get_use_count_map() sums 'num_registered_minibatches'.""" + + class MockFisherBlock(object): + + num_registered_minibatches = 2 + lc = layer_collection.LayerCollection() - lc.fisher_blocks = {'a': 1, ('a', 'c'): 2, ('b', 'c'): 2} + lc.fisher_blocks = { + 'a': MockFisherBlock(), + ('a', 'c'): MockFisherBlock(), + ('b', 'c'): MockFisherBlock() + } use_count_map = lc.get_use_count_map() - self.assertDictEqual({'a': 2, 'b': 1, 'c': 2}, use_count_map) + self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) if __name__ == '__main__': diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..87339cb059802ec8944d5d1ae4557ee34550cd60 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -0,0 +1,101 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.loss_functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import loss_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class InsertSliceInZerosTest(test.TestCase): + + def testBadShape(self): + bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1 + with self.assertRaises(ValueError): + loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) + + def test3d(self): + input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]]) + expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] + op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) + with self.test_session() as sess: + actual_output_array = sess.run(op) + self.assertAllEqual(expected_output_array, actual_output_array) + + +class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2,)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.constant(targets)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... + neg_log_prob = sess.run(neg_log_prob) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py index 5f28f57f6a37074b40fddd690c71292b785490b6..9325aa1b7325fa9cf546d66e6505affa1af7db4d 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.kfac.python.ops import layer_collection as lc -from tensorflow.contrib.kfac.python.ops import loss_functions as lf from tensorflow.contrib.kfac.python.ops import optimizer from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -124,9 +123,8 @@ class OptimizerTest(test.TestCase): def testUpdateVelocities(self): with ops.Graph().as_default(), self.test_session() as sess: layers = lc.LayerCollection() - layers.losses = [ - lf.CategoricalLogitsNegativeLogProbLoss(array_ops.constant([1.0])) - ] + layers.register_categorical_predictive_distribution( + array_ops.constant([1.0])) opt = optimizer.KfacOptimizer( 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular') x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2))) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 779a8179bb07303ff43eba064763c20b9be71dbe..55fe38e3e9aab2dbd70a45cdc8fa0c208b036db0 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -63,6 +63,39 @@ class SequenceDictTest(test.TestCase): self.assertItemsEqual(list(zip(keys, values)), seq_dict.items()) +class SubGraphTest(test.TestCase): + + def testBasicGraph(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + self.assertFalse(sub_graph.is_member(d)) + + def testRepeatedAdds(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + a # note that a appears twice in this graph + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + + def testFilterList(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + input_list = [b, d] + filtered_list = sub_graph.filter_list(input_list) + self.assertEqual(filtered_list, [b]) + + class UtilsTest(test.TestCase): def _fully_connected_layer_params(self): diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index 8b82f6e3147efbc204320f7be631448443287b1b..de4b8920b849dbf2117657de6e7c26f94f4d0363 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -66,6 +66,7 @@ py_library( deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", "//tensorflow/python/ops/distributions", "@six_archive//:six", ], @@ -89,6 +90,7 @@ py_library( ":utils", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", + "//tensorflow/python:util", ], ) @@ -113,7 +115,9 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py index bf59a92fa677810dad62c49e8085d1a8202b7fa0..21b5cde9b931a95110c9a5fd7930a3a4ee74b207 100644 --- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py @@ -36,13 +36,13 @@ class CurvatureMatrixVectorProductComputer(object): For example, the Fisher associated with a log-prob loss w.r.t. the parameters. - The vecs argument to each method are lists of tensors that must be the + The 'vecs' argument to each method are lists of tensors that must be the size as the corresponding ones from "wrt_tensors". They represent the vector being multiplied. "factors" of the matrix M are defined as matrices B such that B*B^T = M. - Methods that multiply by the factor B take a "loss_inner_vecs" argument - instead of vecs, which must be a list of tensors with shapes given by the + Methods that multiply by the factor B take a 'loss_inner_vecs' argument + instead of 'vecs', which must be a list of tensors with shapes given by the corresponding XXX_inner_shapes property. Note that matrix-vector products are not normalized by the batch size, nor @@ -61,7 +61,8 @@ class CurvatureMatrixVectorProductComputer(object): Args: losses: A list of LossFunction instances whose sum defines the total loss. wrt_tensors: A list of Tensors to compute the differential quantities - defining the matrices with respect to (see class description). + (defining the matrices) with respect to. See class description for more + info. """ self._losses = losses self._inputs_to_losses = list(loss.inputs for loss in losses) @@ -73,24 +74,23 @@ class CurvatureMatrixVectorProductComputer(object): return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) # Jacobian multiplication functions: - # NOTE: These implementations use tf.gradients and thus aren't actually - # computing partial derivatives, but total derivatives instead (despite what - # the documentation for tf.gradients says). Because we require partial - # derivatives for Jacobians this implementation will only be correct if the - # partial derivatives are equal to the full derivatives. This happens as long - # as the elements of wrt_tensors don't depend on each other in the graph. If - # these tensors are standard neural network parameters this will be true. def _multiply_jacobian(self, vecs): """Multiply vecs by the Jacobian of losses.""" + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). jacobian_vecs_flat = utils.fwd_gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs) + self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, + stop_gradients=self._wrt_tensors) return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) def _multiply_jacobian_transpose(self, loss_vecs): """Multiply vecs by the transpose Jacobian of losses.""" loss_vecs_flat = nest.flatten(loss_vecs) + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). return gradients_impl.gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat) + self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, + stop_gradients=self._wrt_tensors) # Losses Fisher/Hessian multiplication functions: def _multiply_loss_fisher(self, loss_vecs): diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index c81086416c52d6ed828a8a8fda47a405124ff2b5..6e2c9ecdce7ad9f98a5beb016770ad2b1e197b0a 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -80,6 +80,12 @@ class FisherEstimator(object): self._layers = layer_collection self._layers.create_subgraph() self._check_registration(variables) + self._gradient_fns = { + "gradients": self._get_grads_lists_gradients, + "empirical": self._get_grads_lists_empirical, + "curvature_prop": self._get_grads_lists_curvature_prop, + "exact": self._get_grads_lists_exact + } setup = self._setup(cov_ema_decay) self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup @@ -201,75 +207,73 @@ class FisherEstimator(object): Raises: ValueError: If estimation_mode was improperly specified at construction. """ - damping = self.damping - fisher_blocks_list = self._layers.get_blocks() - tensors_to_compute_grads = [ fb.tensors_to_compute_grads() for fb in fisher_blocks_list ] - tensors_to_compute_grads_flat = nest.flatten(tensors_to_compute_grads) - - if self._estimation_mode == "gradients": - grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(), - tensors_to_compute_grads_flat) - grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat) - grads_lists = tuple((grad,) for grad in grads_all) - - elif self._estimation_mode == "empirical": - grads_flat = gradients_impl.gradients(self._layers.total_loss(), - tensors_to_compute_grads_flat) - grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat) - grads_lists = tuple((grad,) for grad in grads_all) - - elif self._estimation_mode == "curvature_prop": - loss_inputs = list(loss.inputs for loss in self._layers.losses) - loss_inputs_flat = nest.flatten(loss_inputs) - - transformed_random_signs = list(loss.multiply_fisher_factor( - utils.generate_random_signs(loss.fisher_factor_inner_shape)) - for loss in self._layers.losses) - - transformed_random_signs_flat = nest.flatten(transformed_random_signs) - - grads_flat = gradients_impl.gradients(loss_inputs_flat, - tensors_to_compute_grads_flat, - grad_ys - =transformed_random_signs_flat) - grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat) - grads_lists = tuple((grad,) for grad in grads_all) - - elif self._estimation_mode == "exact": - # Loop over all coordinates of all losses. - grads_all = [] - for loss in self._layers.losses: - for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): - transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( - index) - grads_flat = gradients_impl.gradients(loss.inputs, - tensors_to_compute_grads_flat, - grad_ys=transformed_one_hot) - grads_all.append(nest.pack_sequence_as(tensors_to_compute_grads, - grads_flat)) - - grads_lists = zip(*grads_all) - - else: + + try: + grads_lists = self._gradient_fns[self._estimation_mode]( + tensors_to_compute_grads) + except KeyError: raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) for grads_list, fb in zip(grads_lists, fisher_blocks_list): - fb.instantiate_factors(grads_list, damping) + fb.instantiate_factors(grads_list, self.damping) cov_updates = [ factor.make_covariance_update_op(cov_ema_decay) for factor in self._layers.get_factors() ] - inv_updates = { - op.name: op - for factor in self._layers.get_factors() - for op in factor.make_inverse_update_ops() - } + inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()} return control_flow_ops.group(*cov_updates), control_flow_ops.group( *inv_updates.values()), inv_updates + + def _get_all_inverse_update_ops(self): + for factor in self._layers.get_factors(): + for op in factor.make_inverse_update_ops(): + yield op + + def _get_grads_lists_gradients(self, tensors): + grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(), + nest.flatten(tensors)) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_empirical(self, tensors): + grads_flat = gradients_impl.gradients(self._layers.total_loss(), + nest.flatten(tensors)) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_transformed_random_signs(self): + transformed_random_signs = [] + for loss in self._layers.losses: + transformed_random_signs.append( + loss.multiply_fisher_factor( + utils.generate_random_signs(loss.fisher_factor_inner_shape))) + return transformed_random_signs + + def _get_grads_lists_curvature_prop(self, tensors): + loss_inputs = list(loss.inputs for loss in self._layers.losses) + transformed_random_signs = self._get_transformed_random_signs() + grads_flat = gradients_impl.gradients( + nest.flatten(loss_inputs), + nest.flatten(tensors), + grad_ys=nest.flatten(transformed_random_signs)) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_exact(self, tensors): + # Loop over all coordinates of all losses. + grads_all = [] + for loss in self._layers.losses: + for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): + transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( + index) + grads_flat = gradients_impl.gradients( + loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot) + grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) + return zip(*grads_all) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 3bae45b32402c3ea60f3a82b99580d90dc150f86..7ef755c35ed8c75b7614ff5ffd92fb5319fadf65 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -12,7 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""FisherBlock definitions.""" +"""FisherBlock definitions. + +This library contains classes for estimating blocks in a model's Fisher +Information matrix. Suppose one has a model that parameterizes a posterior +distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its +Fisher Information matrix is given by, + + F(params) = E[ v(x, y, params) v(x, y, params)^T ] + +where, + + v(x, y, params) = (d / d params) log p(y | x, params) + +and the expectation is taken with respect to the data's distribution for 'x' and +the model's posterior distribution for 'y', + + x ~ p(x) + y ~ p(y | x, params) + +""" from __future__ import absolute_import from __future__ import division @@ -34,6 +53,14 @@ from tensorflow.python.ops import math_ops NORMALIZE_DAMPING_POWER = 1.0 +def set_global_constants(normalize_damping_power=None): + """Sets various global constants used by the classes in this module.""" + global NORMALIZE_DAMPING_POWER + + if normalize_damping_power is not None: + NORMALIZE_DAMPING_POWER = normalize_damping_power + + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): """Abstract base class for objects modeling approximate Fisher matrix blocks. @@ -87,6 +114,14 @@ class FisherBlock(object): """ pass + @abc.abstractproperty + def num_registered_minibatches(self): + """Number of minibatches registered for this FisherBlock. + + Typically equal to the number of towers in a multi-tower setup. + """ + pass + class FullFB(FisherBlock): """FisherBlock using a full matrix estimate (no approximations). @@ -125,8 +160,9 @@ class FullFB(FisherBlock): def multiply(self, vector): vector_flat = utils.tensors_to_column(vector) - out_flat = (math_ops.matmul(self._factor.get_cov(), vector_flat) + - self._damping * vector_flat) + out_flat = ( + math_ops.matmul(self._factor.get_cov(), vector_flat) + + self._damping * vector_flat) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): @@ -136,6 +172,10 @@ class FullFB(FisherBlock): def tensors_to_compute_grads(self): return self._params + @property + def num_registered_minibatches(self): + return 1 # Multiple minibatches not supported. + class NaiveDiagonalFB(FisherBlock): """FisherBlock using a diagonal matrix approximation. @@ -181,62 +221,139 @@ class NaiveDiagonalFB(FisherBlock): def tensors_to_compute_grads(self): return self._params + @property + def num_registered_minibatches(self): + return 1 # Multiple minibatches not supported. + class FullyConnectedDiagonalFB(FisherBlock): """FisherBlock for fully-connected (dense) layers using a diagonal approx. - Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator - that is computed using the well-known trick. - """ + Estimates the Fisher Information matrix's diagonal entries for a fully + connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of + squares" estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ] + + Consider fully connected layer in this model with (unshared) weight matrix + 'w'. For an example 'x' that produces layer inputs 'a' and output + preactivations 's', + + v(x, y, w) = vec( a (d loss / d s)^T ) - # TODO(jamesmartens): add units tests for this class + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. + """ - def __init__(self, layer_collection, inputs, outputs, has_bias=False): + def __init__(self, layer_collection, has_bias=False): """Creates a FullyConnectedDiagonalFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. - inputs: The Tensor of input activations to this layer. - outputs: The Tensor of output pre-activations from this layer. has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = inputs - self._outputs = outputs + self._inputs = [] + self._outputs = [] self._has_bias = has_bias super(FullyConnectedDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + self._damping = damping self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedDiagonalFactor, (self._inputs, grads_list, - self._has_bias)) + fisher_factors.FullyConnectedDiagonalFactor, + (inputs, grads_list, self._has_bias)) def multiply_inverse(self, vector): + """Approximate damped inverse Fisher-vector product. + + Args: + vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape + [input_size, output_size] corresponding to layer's weights. If not, a + 2-tuple of the former and a Tensor of shape [output_size] corresponding + to the layer's bias. + + Returns: + Tensor of the same shape, corresponding to the inverse Fisher-vector + product. + """ reshaped_vect = utils.layer_params_to_mat2d(vector) reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): + """Approximate damped Fisher-vector product. + + Args: + vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape + [input_size, output_size] corresponding to layer's weights. If not, a + 2-tuple of the former and a Tensor of shape [output_size] corresponding + to the layer's bias. + + Returns: + Tensor of the same shape, corresponding to the Fisher-vector product. + """ reshaped_vect = utils.layer_params_to_mat2d(vector) reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def tensors_to_compute_grads(self): + """Tensors to compute derivative of loss with respect to.""" return self._outputs + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to the + matrix-multiply. + outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + result = len(self._inputs) + assert result == len(self._outputs) + return result + class ConvDiagonalFB(FisherBlock): """FisherBlock for convolutional layers using a diagonal approx. - Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator. + Estimates the Fisher Information matrix's diagonal entries for a convolutional + layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" + estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ] + + Consider a convoluational layer in this model with (unshared) filter matrix + 'w'. For an example image 'x' that produces layer inputs 'a' and output + preactivations 's', + + v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ) + + where 'loc' is a single (x, y) location in an image. + + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. """ - # TODO(jamesmartens): add units tests for this class - def __init__(self, layer_collection, params, inputs, outputs, strides, - padding): + def __init__(self, layer_collection, params, strides, padding): """Creates a ConvDiagonalFB block. Args: @@ -246,37 +363,39 @@ class ConvDiagonalFB(FisherBlock): kernel alone, a Tensor of shape [kernel_height, kernel_width, in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. - inputs: A Tensor of shape [batch_size, height, width, in_channels]. - Input activations to this layer. - outputs: A Tensor of shape [batch_size, height, width, out_channels]. - Output pre-activations from this layer. strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (1-D of Tensor length 4). + padding: The padding in this layer (e.g. "SAME"). """ - self._inputs = inputs - self._outputs = outputs - self._strides = strides + self._inputs = [] + self._outputs = [] + self._strides = tuple(strides) if isinstance(strides, list) else strides self._padding = padding self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) - input_shape = tuple(inputs.shape.as_list()) - self._num_locations = (input_shape[1] * input_shape[2] - // (strides[1] * strides[2])) - super(ConvDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): + # Concatenate inputs, grads_list into single Tensors. + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + # Infer number of locations upon which convolution is applied. + inputs_shape = tuple(inputs.shape.as_list()) + self._num_locations = ( + inputs_shape[1] * inputs_shape[2] // + (self._strides[1] * self._strides[2])) + if NORMALIZE_DAMPING_POWER: damping /= self._num_locations ** NORMALIZE_DAMPING_POWER self._damping = damping self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, - (self._inputs, grads_list, self._filter_shape, self._strides, - self._padding, self._has_bias)) + (inputs, grads_list, self._filter_shape, self._strides, self._padding, + self._has_bias)) def multiply_inverse(self, vector): reshaped_vect = utils.layer_params_to_mat2d(vector) @@ -291,6 +410,22 @@ class ConvDiagonalFB(FisherBlock): def tensors_to_compute_grads(self): return self._outputs + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to + the convolution. + outputs: Tensor of shape [batch_size, height, width, output_size]. Layer + preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + return len(self._inputs) + class KroneckerProductFB(FisherBlock): """A base class for FisherBlocks with separate input and output factors. @@ -337,10 +472,12 @@ class KroneckerProductFB(FisherBlock): left_factor = self._input_factor.get_cov() right_factor = self._output_factor.get_cov() reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = (math_ops.matmul(reshaped_vector, right_factor) + - self._output_damping * reshaped_vector) - reshaped_out = (math_ops.matmul(left_factor, reshaped_out) + - self._input_damping * reshaped_out) + reshaped_out = ( + math_ops.matmul(reshaped_vector, right_factor) + + self._output_damping * reshaped_vector) + reshaped_out = ( + math_ops.matmul(left_factor, reshaped_out) + + self._input_damping * reshaped_out) if self._renorm_coeff != 1.0: reshaped_out *= math_ops.cast( self._renorm_coeff, dtype=reshaped_out.dtype) @@ -367,34 +504,64 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB): K-FAC paper (https://arxiv.org/abs/1503.05671) """ - def __init__(self, layer_collection, inputs, outputs, has_bias=False): + def __init__(self, layer_collection, has_bias=False): """Creates a FullyConnectedKFACBasicFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. - inputs: The Tensor of input activations to this layer. - outputs: The Tensor of output pre-activations from this layer. has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = inputs - self._outputs = outputs + self._inputs = [] + self._outputs = [] self._has_bias = has_bias super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, ((self._inputs,), - self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + # TODO(b/68033310): Validate which of, + # (1) summing on a single device (as below), or + # (2) on each device in isolation and aggregating + # is faster. + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.FullyConnectedKroneckerFactor, # + ((inputs,), self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.FullyConnectedKroneckerFactor, # + (grads_list,)) self._register_damped_input_and_output_inverses(damping) def tensors_to_compute_grads(self): return self._outputs + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to the + matrix-multiply. + outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + return 1 # Multiple minibatches not supported. + class ConvKFCBasicFB(KroneckerProductFB): """FisherBlock for 2D convolutional layers using the basic KFC approx. @@ -430,8 +597,8 @@ class ConvKFCBasicFB(KroneckerProductFB): self._filter_shape = tuple(fltr.shape.as_list()) input_shape = tuple(inputs.shape.as_list()) - self._num_locations = (input_shape[1] * input_shape[2] // - (strides[1] * strides[2])) + self._num_locations = ( + input_shape[1] * input_shape[2] // (strides[1] * strides[2])) super(ConvKFCBasicFB, self).__init__(layer_collection) @@ -453,3 +620,34 @@ class ConvKFCBasicFB(KroneckerProductFB): def tensors_to_compute_grads(self): return self._outputs + + @property + def num_registered_minibatches(self): + return 1 # Multiple minibatches not supported. + + +def _concat_along_batch_dim(tensor_list): + """Concatenate tensors along batch (first) dimension. + + Args: + tensor_list: list of Tensors or list of tuples of Tensors. + + Returns: + Tensor or tuple of Tensors. + + Raises: + ValueError: If 'tensor_list' is empty. + + """ + if not tensor_list: + raise ValueError( + "Cannot concatenate Tensors if there are no Tensors to concatenate.") + + if isinstance(tensor_list[0], (tuple, list)): + # [(tensor1a, tensor1b), + # (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b) + return tuple( + array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list)) + else: + # [tensor1, tensor2] --> tensor + return array_ops.concat(tensor_list, axis=0) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index c6cc169b3784ca2e60cde6cd703f13ddeaaad985..59389f8d385c18f50914d690cfaa2825ef807ed3 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -31,7 +31,8 @@ _allowed_symbols = [ 'KroneckerProductFB', 'FullyConnectedKFACBasicFB', 'ConvKFCBasicFB', - 'ConvDiagonalFB' + 'ConvDiagonalFB', + 'set_global_constants', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index eacd9f53b1b1471ae6f77a35cbfcbb33d5434e2c..b8b524406c3cac41aa70047be5e601e5aaee8e01 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -33,9 +33,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages -# TODO(someone): come up with a better mechanism to set these constants -# externally. See b/67084987 - # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). INIT_COVARIANCES_AT_ZERO = False @@ -53,6 +50,25 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 EIGENVALUE_CLIPPING_THRESHOLD = 0.0 +def set_global_constants(init_covariances_at_zero=None, zero_debias=None, + eigenvalue_decomposition_threshold=None, + eigenvalue_clipping_threshold=None): + """Sets various global constants used by the classes in this module.""" + global INIT_COVARIANCES_AT_ZERO + global ZERO_DEBIAS + global EIGENVALUE_DECOMPOSITION_THRESHOLD + global EIGENVALUE_CLIPPING_THRESHOLD + + if init_covariances_at_zero is not None: + INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero + if zero_debias is not None: + ZERO_DEBIAS = zero_debias + if eigenvalue_decomposition_threshold is not None: + EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold + if eigenvalue_clipping_threshold is not None: + EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + + def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument return array_ops.diag(array_ops.ones(shape[0], dtype)) @@ -412,11 +428,28 @@ class NaiveDiagonalFactor(DiagonalFactor): class FullyConnectedDiagonalFactor(DiagonalFactor): - """FisherFactor for a diagonal approx of a fully-connected layer's Fisher.""" + r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. + + Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], + approximates the covariance as, + + Cov(in, out) = (1/batch_size) \sum_{i} outer(in[i], out_grad[i]) ** 2.0 + + where the square is taken element-wise. + """ # TODO(jamesmartens): add units tests for this class def __init__(self, inputs, outputs_grads, has_bias=False): + """Instantiate FullyConnectedDiagonalFactor. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to fully + connected layer. + outputs_grads: List of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to layer's preactivations. + has_bias: bool. If True, append '1' to each input. + """ self._outputs_grads = outputs_grads self._batch_size = array_ops.shape(inputs)[0] self._orig_tensors_name = scope_string_from_params((inputs,) + @@ -540,6 +573,14 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """ def __init__(self, tensors, has_bias=False): + """Instantiate FullyConnectedKroneckerFactor. + + Args: + tensors: List of Tensors of shape [batch_size, n]. Represents either a + layer's inputs or its output's gradients. + has_bias: bool. If True, assume this factor is for the layer's inputs and + append '1' to each row. + """ # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. self._has_bias = has_bias diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index 49a07b15986b946105d32a1950bcccabaa363cef..23ee93cd405bbf719939df89d525c812ee061f8b 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -40,6 +40,7 @@ _allowed_symbols = [ "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", "ConvDiagonalFactor", + "set_global_constants", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 1b77f5d3ba9820167e406dff3d55ef7d46d7482c..2b9958a46a67621ee50f5b42fa3a9e374398f66d 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -27,6 +27,8 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +import six + from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb from tensorflow.contrib.kfac.python.ops import loss_functions as lf from tensorflow.contrib.kfac.python.ops import utils @@ -37,10 +39,15 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +# Names for various approximations that can be requested for Fisher blocks. APPROX_KRONECKER_NAME = "kron" APPROX_DIAGONAL_NAME = "diagonal" APPROX_FULL_NAME = "full" +# Possible value for 'reuse' keyword argument. Sets 'reuse' to +# tf.get_variable_scope().reuse. +VARIABLE_SCOPE = "VARIABLE_SCOPE" + # TODO(jamesmartens): need to add find_canonical_output back into this somewhere @@ -55,6 +62,7 @@ class LayerParametersDict(OrderedDict): super(LayerParametersDict, self).__init__(*args, **kwargs) def __setitem__(self, key, value): + key = self._canonicalize_key(key) tensors = key if isinstance(key, (tuple, list)) else (key,) key_collisions = self._tensors.intersection(tensors) if key_collisions: @@ -63,12 +71,26 @@ class LayerParametersDict(OrderedDict): super(LayerParametersDict, self).__setitem__(key, value) def __delitem__(self, key): + key = self._canonicalize_key(key) self._tensors.remove(key) super(LayerParametersDict, self).__delitem__(key) + def __getitem__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__getitem__(key) + + def __contains__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__contains__(key) -# TODO(duckworthd): add capability for LayerCollection to be "finalized" -# and do this when it gets used by FisherEstimator / KfacOptimizer + def _canonicalize_key(self, key): + if isinstance(key, (list, tuple)): + return tuple(key) + return key + + +# TODO(b/68034464): add capability for LayerCollection to be "finalized" +# and do this when it gets used by FisherEstimator / KfacOptimizer. class LayerCollection(object): @@ -94,13 +116,16 @@ class LayerCollection(object): self.fisher_factors = OrderedDict() self._generic_registrations = set() self._graph = graph or ops.get_default_graph() - self.losses = [] + self._loss_dict = {} # {str: LossFunction} self._subgraph = None with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name - reset_internals = __init__ + @property + def losses(self): + """LossFunctions registered with this LayerCollection.""" + return list(self._loss_dict.values()) def register_block(self, layer_key, fisher_block): """Validates and registers the layer_key associated with the fisher_block. @@ -193,10 +218,10 @@ class LayerCollection(object): def get_use_count_map(self): """Returns a dict of variables to their number of registrations.""" vars_to_uses = defaultdict(int) - for key in self.fisher_blocks.keys(): + for key, block in six.iteritems(self.fisher_blocks): key = key if isinstance(key, (tuple, list)) else (key,) for k in key: - vars_to_uses[k] += 1 + vars_to_uses[k] += block.num_registered_minibatches return vars_to_uses def get_blocks(self): @@ -234,18 +259,57 @@ class LayerCollection(object): params, inputs, outputs, - approx=APPROX_KRONECKER_NAME): + approx=APPROX_KRONECKER_NAME, + reuse=VARIABLE_SCOPE): + """Registers a fully connnected layer. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_size]. Preactivations + produced by layer. + approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + approx_to_block_types = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, + } + + if approx not in approx_to_block_types: + raise ValueError("Bad value {} for approx.".format(approx)) + + block_type = approx_to_block_types[approx] has_bias = isinstance(params, (tuple, list)) - if approx == APPROX_KRONECKER_NAME: - self.register_block(params, - fb.FullyConnectedKFACBasicFB(self, inputs, outputs, - has_bias)) - elif approx == APPROX_DIAGONAL_NAME: - self.register_block(params, - fb.FullyConnectedDiagonalFB(self, inputs, outputs, - has_bias)) + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + block = self.fisher_blocks.get(params, None) + if block is None: + raise KeyError( + "Reuse requested but no FisherBlock found for params {}.".format( + params)) + if not isinstance(block, block_type): + raise ValueError( + "Requested block of type {} but block of type {} already exists " + "for params {}.".format(block_type, type(block), params)) + else: - raise ValueError("Bad value {} for approx.".format(approx)) + block = block_type(self, has_bias) + self.register_block(params, block) + + block.register_additional_minibatch(inputs, outputs) def register_conv2d(self, params, strides, padding, inputs, outputs, approx=APPROX_KRONECKER_NAME): @@ -255,9 +319,9 @@ class LayerCollection(object): fb.ConvKFCBasicFB(self, params, inputs, outputs, strides, padding)) elif approx == APPROX_DIAGONAL_NAME: - self.register_block(params, - fb.ConvDiagonalFB(self, params, inputs, outputs, - strides, padding)) + block = fb.ConvDiagonalFB(self, params, strides, padding) + block.register_additional_minibatch(inputs, outputs) + self.register_block(params, block) def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME): params = params if isinstance(params, (tuple, list)) else (params,) @@ -277,7 +341,9 @@ class LayerCollection(object): def register_categorical_predictive_distribution(self, logits, seed=None, - targets=None): + targets=None, + name=None, + reuse=VARIABLE_SCOPE): """Registers a categorical predictive distribution. Args: @@ -288,16 +354,55 @@ class LayerCollection(object): total_loss() is required, for example, to estimate the "empirical Fisher" (instead of the true Fisher). (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. + If False, create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: If reuse=True and name != None. + ValueError: If reuse=True and seed != None. + KeyError: If reuse=True and no existing LossFunction with 'name' found. + KeyError: If reuse=False and existing LossFunction with 'name' found. """ - loss = lf.CategoricalLogitsNegativeLogProbLoss( - logits, targets=targets, seed=seed) - self.losses.append(loss) + name = name or self._graph.unique_name( + "register_categorical_predictive_distribution") + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + if name is None: + raise ValueError( + "If reuse is enabled, loss function's name must be set.") + if seed is not None: + raise ValueError( + "Seed can only be specified at LossFunction instantiation.") + + loss = self._loss_dict.get(name, None) + + if loss is None: + raise KeyError( + "Unable to find loss function named {}. Create a new LossFunction " + "with reuse=False.".format(name)) + + loss.register_additional_minibatch(logits, targets=targets) + else: + if name in self._loss_dict: + raise KeyError( + "Loss function named {} already exists. Set reuse=True to append " + "another minibatch.".format(name)) + loss = lf.CategoricalLogitsNegativeLogProbLoss( + logits, targets=targets, seed=seed) + self._loss_dict[name] = loss def register_normal_predictive_distribution(self, mean, var=0.5, seed=None, - targets=None): + targets=None, + name=None): """Registers a normal predictive distribution. Args: @@ -312,15 +417,23 @@ class LayerCollection(object): total_loss() is required, for example, to estimate the "empirical Fisher" (instead of the true Fisher). (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) """ + name = name or self._graph.unique_name( + "register_normal_predictive_distribution") + if name in self._loss_dict: + raise NotImplementedError( + "Adding logits to an existing LossFunction not yet supported.") loss = lf.NormalMeanNegativeLogProbLoss( mean, var, targets=targets, seed=seed) - self.losses.append(loss) + self._loss_dict[name] = loss def register_multi_bernoulli_predictive_distribution(self, logits, seed=None, - targets=None): + targets=None, + name=None): """Registers a multi-Bernoulli predictive distribution. Args: @@ -331,12 +444,40 @@ class LayerCollection(object): total_loss() is required, for example, to estimate the "empirical Fisher" (instead of the true Fisher). (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) """ + name = name or self._graph.unique_name( + "register_multi_bernoulli_predictive_distribution") + if name in self._loss_dict: + raise NotImplementedError( + "Adding logits to an existing LossFunction not yet supported.") loss = lf.MultiBernoulliNegativeLogProbLoss( logits, targets=targets, seed=seed) - self.losses.append(loss) + self._loss_dict[name] = loss def make_or_get_factor(self, cls, args): + """Insert 'cls(args)' into 'self.fisher_factors' if not already present. + + Wraps constructor in 'tf.variable_scope()' to ensure variables constructed + in 'cls.__init__' are placed under this LayerCollection's scope. + + Args: + cls: Class that implements FisherFactor. + args: Tuple of arguments to pass into 'cls's constructor. Must be + hashable. + + Returns: + Instance of 'cls' found in self.fisher_factors. + """ + try: + hash(args) + except TypeError: + raise TypeError(( + "Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. The pair cannot be hashed." + ).format(cls, args)) + with variable_scope.variable_scope(self._var_scope): return utils.setdefault(self.fisher_factors, (cls, args), lambda: cls(*args)) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index 63a9b173bc809a7f25b382a3639462c27b39c5f9..d6bf61a210203dd74d4e93b65005f660b1fab4ff 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -35,6 +35,7 @@ _allowed_symbols = [ "APPROX_KRONECKER_NAME", "APPROX_DIAGONAL_NAME", "APPROX_FULL_NAME", + "VARIABLE_SCOPE", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index d80382b9cf31d784d7d2267a18cf88362fea95fc..3cfde7f9ababab73980e93ea1dd65be1b559712b 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -42,8 +42,14 @@ class LossFunction(object): use this class. It depends on the use case. """ - def __init__(self, targets=None): - self._targets = targets + @abc.abstractproperty + def targets(self): + """The targets being predicted by the model. + + Returns: + None or Tensor of appropriate shape for calling self._evaluate() on. + """ + pass @abc.abstractproperty def inputs(self): @@ -51,16 +57,25 @@ class LossFunction(object): pass def evaluate(self): - """Evaluate the loss function.""" - if self._targets is not None: + """Evaluate the loss function on the targets.""" + if self.targets is not None: # We treat the targets as "constant". It's only the inputs that get # "back-propped" through. - return self._evaluate(array_ops.stop_gradient(self._targets)) + return self._evaluate(array_ops.stop_gradient(self.targets)) else: raise Exception("Cannot evaluate losses with unspecified targets.") @abc.abstractmethod def _evaluate(self, targets): + """Evaluates the log probability of the targets. + + Args: + targets: Tensor that distribution can calculate log_prob() of. + + Returns: + log probability of each target, summed across all targets. + """ + pass @abc.abstractmethod @@ -104,7 +119,7 @@ class LossFunction(object): @abc.abstractmethod def multiply_hessian_factor_transpose(self, vector): - """Right-multiply a vector by the tranpose of a factor B of the Hessian. + """Right-multiply a vector by the transpose of a factor B of the Hessian. Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) of the loss function with respect to its inputs. Typically this will be @@ -166,9 +181,9 @@ class LossFunction(object): class NegativeLogProbLoss(LossFunction): """Abstract base class for loss functions that are negative log probs.""" - def __init__(self, targets=None, seed=None): + def __init__(self, seed=None): self._default_seed = seed - super(NegativeLogProbLoss, self).__init__(targets=targets) + super(NegativeLogProbLoss, self).__init__() @property def inputs(self): @@ -176,6 +191,7 @@ class NegativeLogProbLoss(LossFunction): @abc.abstractproperty def params(self): + """Parameters to the underlying distribution.""" pass @abc.abstractmethod @@ -218,7 +234,7 @@ class NegativeLogProbLoss(LossFunction): @abc.abstractmethod def multiply_fisher_factor_transpose(self, vector): - """Right-multiply a vector by the tranpose of a factor B of the Fisher. + """Right-multiply a vector by the transpose of a factor B of the Fisher. Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- product of gradients) with respect to the parameters of the underlying @@ -281,9 +297,18 @@ class NegativeLogProbLoss(LossFunction): @abc.abstractmethod def sample(self, seed): + """Sample 'targets' from the underlying distribution.""" pass def evaluate_on_sample(self, seed=None): + """Evaluates the log probability on a random sample. + + Args: + seed: int or None. Random seed for this draw from the distribution. + + Returns: + Log probability of sampled targets, summed across examples. + """ if seed is None: seed = self._default_seed # We treat the targets as "constant". It's only the inputs that get @@ -328,16 +353,19 @@ class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss): class DistributionNegativeLogProbLoss(NegativeLogProbLoss): """Base class for neg log prob losses that use the TF Distribution classes.""" - def __init__(self, dist, targets=None, seed=None): - self._dist = dist - super(DistributionNegativeLogProbLoss, self).__init__( - targets=targets, seed=seed) + def __init__(self, seed=None): + super(DistributionNegativeLogProbLoss, self).__init__(seed=seed) + + @abc.abstractproperty + def dist(self): + """The underlying tf.distributions.Distribution.""" + pass def _evaluate(self, targets): - return -math_ops.reduce_sum(self._dist.log_prob(targets)) + return -math_ops.reduce_sum(self.dist.log_prob(targets)) def sample(self, seed): - return self._dist.sample(seed=seed) + return self.dist.sample(seed=seed) class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, @@ -355,11 +383,18 @@ class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, """ def __init__(self, mean, var=0.5, targets=None, seed=None): - dist = normal.Normal(loc=mean, scale=var**0.5) self._mean = mean self._var = var - super(NormalMeanNegativeLogProbLoss, self).__init__( - dist, targets=targets, seed=seed) + self._targets = targets + super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var)) @property def params(self): @@ -397,7 +432,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): This class parameterizes a multivariate normal distribution with n independent dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not - assume the variance is held constant. The Fisher Information for for n = 1 + assume the variance is held constant. The Fisher Information for n = 1 is given by, F = [[1 / variance, 0], @@ -416,10 +451,16 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): self._mean = mean self._variance = variance self._scale = math_ops.sqrt(variance) - dist = normal.Normal(loc=self._mean, scale=self._scale) - super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(dist, - targets=targets, - seed=seed) + self._targets = targets + super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=self._scale) @property def params(self): @@ -534,12 +575,53 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, """ def __init__(self, logits, targets=None, seed=None): - dist = categorical.Categorical(logits=logits) - self._logits = logits - self._probs = dist.probs - self._sqrt_probs = math_ops.sqrt(self._probs) - super(CategoricalLogitsNegativeLogProbLoss, self).__init__( - dist, targets=targets, seed=seed) + """Instantiates a CategoricalLogitsNegativeLogProbLoss. + + Args: + logits: Tensor of shape [batch_size, output_size]. Parameters for + underlying distribution. + targets: None or Tensor of shape [output_size]. Each elements contains an + index in [0, output_size). + seed: int or None. Default random seed when sampling. + """ + self._logits_components = [] + self._targets_components = [] + self.register_additional_minibatch(logits, targets=targets) + super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) + + def register_additional_minibatch(self, logits, targets=None): + """Register an additiona minibatch's worth of parameters. + + Args: + logits: Tensor of shape [batch_size, output_size]. Parameters for + underlying distribution. + targets: None or Tensor of shape [batch_size, output_size]. Each row must + be a one-hot vector. + """ + self._logits_components.append(logits) + self._targets_components.append(targets) + + @property + def _logits(self): + return array_ops.concat(self._logits_components, axis=0) + + @property + def targets(self): + if all(target is None for target in self._targets_components): + return None + return array_ops.concat(self._targets_components, axis=0) + + @property + def dist(self): + return categorical.Categorical(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs + + @property + def _sqrt_probs(self): + return math_ops.sqrt(self._probs) @property def params(self): @@ -595,12 +677,21 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, """ def __init__(self, logits, targets=None, seed=None): - dist = bernoulli.Bernoulli(logits=logits) self._logits = logits - self._probs = dist.probs + self._targets = targets + super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets - super(MultiBernoulliNegativeLogProbLoss, self).__init__( - dist, targets=targets, seed=seed) + @property + def dist(self): + return bernoulli.Bernoulli(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs @property def params(self): @@ -632,11 +723,12 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): - """Inserts slice into a larger tensors of zeros. + """Inserts slice into a larger tensor of zeros. - Forms a new tensor that which is the same shape as slice_, except that + Forms a new tensor which is the same shape as slice_to_insert, except that the dimension given by 'dim' is expanded to the size given by 'dim_size'. - 'position' determines the position (index) of the slice in that dimension. + 'position' determines the position (index) at which to insert the slice within + that dimension. Assumes slice_to_insert.shape[dim] = 1. @@ -644,7 +736,7 @@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): slice_to_insert: The slice to insert. dim: The dimension which to expand with zeros. dim_size: The new size of the 'dim' dimension. - position: The position of 'slice_' in the new tensor. + position: The position of 'slice_to_insert' in the new tensor. Returns: The new tensor. @@ -662,4 +754,4 @@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): before[dim] = position after[dim] = dim_size - position - 1 - return array_ops.pad(slice_to_insert, zip(before, after)) + return array_ops.pad(slice_to_insert, list(zip(before, after))) diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py index 0617c5be4d9fc0412245bb12f5d1ff2ca3ee420a..831870fca451c585cb1a1dc6b24aad757e2bbaa8 100644 --- a/tensorflow/contrib/kfac/python/ops/op_queue.py +++ b/tensorflow/contrib/kfac/python/ops/op_queue.py @@ -61,7 +61,7 @@ class OpQueue(object): sess: tf.Session. Returns: - Next Op chosen from from 'ops'. + Next Op chosen from 'ops'. """ # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') # returns a str. diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index b34b4e10adb549990b63e9726a88294d03ecb59a..a7473481e44da0b09c047db9af29032918ea6cef 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -250,7 +250,7 @@ def generate_random_signs(shape, dtype=dtypes.float32): return 2 * math_ops.cast(ints, dtype=dtype) - 1 -def fwd_gradients(ys, xs, grad_xs=None): +def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): """Compute forward-mode gradients.""" # See b/37888268. @@ -260,7 +260,8 @@ def fwd_gradients(ys, xs, grad_xs=None): # generated by the first gradients_impl.gradients call. us = [array_ops.zeros_like(y) + float("nan") for y in ys] - dydxs = gradients_impl.gradients(ys, xs, grad_ys=us) + dydxs = gradients_impl.gradients(ys, xs, grad_ys=us, + stop_gradients=stop_gradients) # Deal with strange types that gradients_impl.gradients returns but can't # deal with. diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 4eba29caecbddc408d168158daf8377aedab7bcc..894e6f6946bb59810a9da2d304cc0dd43d25201d 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -109,9 +109,9 @@ py_test( ":test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", ], ) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index bbb4fb1f57b54848e538d0cd1fad90ce0b6feab0..1ae4d281c403d47e12b8d89ff30a31b3621b9223 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -153,10 +153,10 @@ py_test( deps = [ ":layers_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", + "//tensorflow/python:session", "//third_party/py/numpy", ], ) @@ -168,9 +168,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -238,6 +238,7 @@ py_test( ":layers_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lookup_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", @@ -280,9 +281,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//tensorflow/python:variables", ], ) @@ -294,9 +295,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//third_party/py/numpy", diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index d8ab7c2d70d8a7346c04d326f3a51b40a4f900ea..d309ba958ded86afdc1e4bba2ff471a5181cda4e 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -47,6 +47,7 @@ See the @{$python/contrib.layers} guide. @@separable_conv2d @@separable_convolution2d @@softmax +@@spatial_softmax @@stack @@unit_norm @@bow_encoder diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 29ab281b1a603df153619eed2336420ddde9f6a8..deeafdf300f4e060f8c840b7ccd0e18888f68638 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -463,7 +463,8 @@ def batch_norm(inputs, scope=None, renorm=False, renorm_clipping=None, - renorm_decay=0.99): + renorm_decay=0.99, + adjustment=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing @@ -546,6 +547,17 @@ def batch_norm(inputs, and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `decay` is still applied to get the means and variances for inference. + adjustment: A function taking the `Tensor` containing the (dynamic) shape of + the input tensor and returning a pair (scale, bias) to apply to the + normalized values (before gamma and beta), only during training. For + example, + `adjustment = lambda shape: ( + tf.random_uniform(shape[-1:], 0.93, 1.07), + tf.random_uniform(shape[-1:], -0.1, 0.1))` + will scale the normalized value by up to 7% up or down, then shift the + result by up to 0.1 (with independent scaling and bias for each feature + but shared across all examples), and finally apply gamma and/or beta. If + `None`, no adjustment is applied. Returns: A `Tensor` representing the output of the operation. @@ -569,7 +581,10 @@ def batch_norm(inputs, # implementation in normalization_layers.BatchNormalization. inputs = ops.convert_to_tensor(inputs) rank = inputs.get_shape().ndims - possible_to_fuse = batch_weights is None and not renorm and rank in [2, 4] + possible_to_fuse = (batch_weights is None and + not renorm and + rank in [2, 4] and + adjustment is None) if fused and possible_to_fuse and ( zero_debias_moving_mean or rank == 2 or updates_collections is not ops.GraphKeys.UPDATE_OPS): @@ -636,6 +651,7 @@ def batch_norm(inputs, renorm=renorm, renorm_clipping=renorm_clipping, renorm_momentum=renorm_decay, + adjustment=adjustment, name=sc.name, _scope=sc, _reuse=reuse, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 1040ad3ca7a4bbd56584f8e2cb8b2a2c8029d418..7c77e905f7432db4e42e7fda70aa72f32f40bb09 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -2644,6 +2644,26 @@ class BatchNormTest(test.TestCase): zero_debias_moving_mean=True) sess.run(variables_lib.global_variables_initializer()) + def testAdjustmentCreated(self): + # Tests that the adjustment is appropriately passed to and used by the core + # BN layer. + all_adjustments = [] + def _create_adjustment(shape): + adjustments = [array_ops.ones(shape[-1:]), array_ops.zeros(shape[-1:])] + all_adjustments.extend(adjustments) + return adjustments + depth = 8 + images = array_ops.zeros([10, 5, 5, depth]) + output = _layers.batch_norm( + images, + is_training=True, + adjustment=_create_adjustment) + self.assertListEqual(output.shape.as_list(), images.shape.as_list()) + self.assertEqual(len(all_adjustments), 2) + self.assertListEqual(all_adjustments[0].shape.as_list(), [depth]) + self.assertListEqual(all_adjustments[1].shape.as_list(), [depth]) + + class LayerNormTest(test.TestCase): def testUnknownShape(self): diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 33db93b9704eb3c81d042e2636f916d5f685ad97..cdceea6fee5bdb5aeb6537ea55d25ccf107def4c 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -41,7 +41,7 @@ OPTIMIZER_CLS_NAMES = { "Adagrad": train.AdagradOptimizer, "Adam": train.AdamOptimizer, "Ftrl": train.FtrlOptimizer, - "Momentum": train.MomentumOptimizer, + "Momentum": lambda lr: train.MomentumOptimizer(lr, momentum=0.9), "RMSProp": train.RMSPropOptimizer, "SGD": train.GradientDescentOptimizer, } diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 8813a99f1994ade17cca3b1371a17278e434cef9..1ea25bd1a5685eb6f840e621b5739029a660aa0f 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -176,7 +176,7 @@ class OptimizersTest(test.TestCase): session.run(train, feed_dict={x: 5}) var_value, global_step_value = session.run([var, global_step]) # Due to randomness the following number may change if graph is different. - self.assertAlmostEqual(var_value, 8.5591021, 4) + self.assertAlmostEqual(var_value, 9.86912, 4) self.assertEqual(global_step_value, 1) def testGradientNoiseWithClipping(self): @@ -193,7 +193,7 @@ class OptimizersTest(test.TestCase): variables.global_variables_initializer().run() session.run(train, feed_dict={x: 5}) var_value, global_step_value = session.run([var, global_step]) - self.assertAlmostEqual(var_value, 9.0, 4) + self.assertAlmostEqual(var_value, 9.86912, 4) self.assertEqual(global_step_value, 1) def testGradientClip(self): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index f3949beed04655456b3f0b550f5c757c85899270..ac615b120c16d5d9a7798874653f8f00f8fd15b4 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -768,7 +768,7 @@ py_test( ":learn", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/session_bundle:exporter", - "//tensorflow/contrib/session_bundle:manifest_proto_py", + "//tensorflow/contrib/session_bundle:manifest_proto_py_pb2", "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 2fec0508a5603768a301e1e2f9c251a89cf0ef69..12f9bba531a296a00d17956b8ce32e5d7dead380 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -348,6 +348,12 @@ class DNNClassifierTest(test.TestCase): for prediction in predictions: self.assertIn(prediction, (0, 1)) + def _assertClassificationPredictions( + self, expected_len, n_classes, predictions): + self.assertEqual(expected_len, len(predictions)) + for prediction in predictions: + self.assertIn(prediction, range(n_classes)) + def _assertProbabilities(self, expected_batch_size, expected_n_classes, probabilities): self.assertEqual(expected_batch_size, len(probabilities)) @@ -732,7 +738,7 @@ class DNNClassifierTest(test.TestCase): self.assertIn('loss', scores) predicted_classes = classifier.predict_classes( input_fn=_input_fn, as_iterable=False) - self._assertBinaryPredictions(3, predicted_classes) + self._assertClassificationPredictions(3, n_classes, predicted_classes) predictions = classifier.predict(input_fn=_input_fn, as_iterable=False) self.assertAllEqual(predicted_classes, predictions) probabilities = classifier.predict_proba( @@ -765,8 +771,9 @@ class DNNClassifierTest(test.TestCase): feature_column.real_valued_column('age') ] + n_classes = 3 classifier = dnn.DNNClassifier( - n_classes=3, + n_classes=n_classes, feature_columns=feature_columns, hidden_units=[3, 3], config=run_config.RunConfig(tf_random_seed=1)) @@ -780,7 +787,7 @@ class DNNClassifierTest(test.TestCase): predicted_classes = list( classifier.predict_classes( input_fn=predict_input_fn, as_iterable=True)) - self.assertListEqual(predicted_classes, [1, 0, 0]) + self._assertClassificationPredictions(3, n_classes, predicted_classes) predictions = list( classifier.predict( input_fn=predict_input_fn, as_iterable=True)) @@ -788,8 +795,7 @@ class DNNClassifierTest(test.TestCase): predicted_proba = list( classifier.predict_proba( input_fn=predict_input_fn, as_iterable=True)) - self.assertAllClose( - predicted_proba, [[0., 1., 0.], [1., 0., 0.], [1., 0., 0.]], atol=0.3) + self._assertProbabilities(3, n_classes, predicted_proba) def testCustomMetrics(self): """Tests custom evaluation metrics.""" @@ -1214,6 +1220,12 @@ class DNNRegressorTest(test.TestCase): scores = regressor.evaluate(input_fn=_input_fn_eval, steps=1) self.assertIn('loss', scores) + def _assertRegressionOutputs( + self, predictions, expected_shape): + predictions_nparray = np.array(predictions) + self.assertAllEqual(expected_shape, predictions_nparray.shape) + self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.float)) + def testPredict_AsIterableFalse(self): """Tests predict method with as_iterable=False.""" labels = [1., 0., 0.2] @@ -1252,7 +1264,7 @@ class DNNRegressorTest(test.TestCase): self.assertIn('loss', scores) predicted_scores = regressor.predict_scores( input_fn=_input_fn, as_iterable=False) - self.assertAllClose(labels, predicted_scores, atol=0.2) + self._assertRegressionOutputs(predicted_scores, [3]) predictions = regressor.predict(input_fn=_input_fn, as_iterable=False) self.assertAllClose(predicted_scores, predictions) @@ -1296,7 +1308,7 @@ class DNNRegressorTest(test.TestCase): predicted_scores = list( regressor.predict_scores( input_fn=predict_input_fn, as_iterable=True)) - self.assertAllClose(labels, predicted_scores, atol=0.2) + self._assertRegressionOutputs(predicted_scores, [3]) predictions = list( regressor.predict(input_fn=predict_input_fn, as_iterable=True)) self.assertAllClose(predicted_scores, predictions) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 1724d7599d09873f969555cc9382c0753eba463f..69440e823ef1ed2d739f28bc14587891f2de80bb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -639,7 +639,7 @@ class DynamicRnnEstimator(estimator.Estimator): ValueError: `problem_type` is not one of `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`. ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but - `num_classes` is not specifieProblemType + `num_classes` is not specified. ValueError: `prediction_type` is not one of `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`. """ diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 8bb1c83a451d7cd27f4df04f983cdd23d1e136ae..788d2d0b1a58fad16712c968593b40de0d3979f0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -981,9 +981,8 @@ class BaseEstimator( global_step = training_util.create_global_step(g) features, labels = input_fn() self._check_inputs(features, labels) - global_step_read_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access - with ops.control_dependencies([global_step_read_tensor]): - model_fn_ops = self._get_train_ops(features, labels) + training_util._get_or_create_global_step_read() # pylint: disable=protected-access + model_fn_ops = self._get_train_ops(features, labels) ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) all_hooks.extend(hooks) all_hooks.extend([ diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index a67694d1c93c9d01bf63fc216b83d87ab390c456..468d792a0dccf5cf046a41ed8e1600940a15ac37 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -33,7 +33,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib @@ -635,10 +634,11 @@ def _create_model_fn_ops(features, if (mode != model_fn.ModeKeys.INFER) and (labels is not None): weight_tensor = _weight_tensor(features, weight_column_name) loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor) - # Uses the deprecated API to set the tag explicitly. - # Without it, training and eval losses will show up in different graphs. - logging_ops.scalar_summary( - _summary_key(head_name, mkey.LOSS), weighted_average_loss) + # The name_scope escapism is needed to maintain the same summary tag + # after switching away from the now unsupported API. + with ops.name_scope(""): + summary_loss = array_ops.identity(weighted_average_loss) + summary.scalar(_summary_key(head_name, mkey.LOSS), summary_loss) if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: @@ -1484,8 +1484,12 @@ class _LossOnlyHead(Head): loss = self._loss_fn() if isinstance(loss, list): loss = math_ops.add_n(loss) - logging_ops.scalar_summary( - _summary_key(self.head_name, mkey.LOSS), loss) + # The name_scope escapism is needed to maintain the same summary tag + # after switching away from the now unsupported API. + with ops.name_scope(""): + summary_loss = array_ops.identity(loss) + summary.scalar(_summary_key(self.head_name, mkey.LOSS), + summary_loss) if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError("train_op_fn can not be None in TRAIN mode") diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index b4d9c3fc6fb5906de93950e46a41fb97b24f779f..992b804f59ecd88fedc2fba10d3079f93c4fe83d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of k-means clustering on top of `Estimator` API.""" +"""Implementation of k-means clustering on top of `Estimator` API. + +This module is deprecated. Please use +@{tf.contrib.factorization.KMeansClustering} instead of +@{tf.contrib.learn.KMeansClustering}. It has a similar interface, but uses the +@{tf.estimator.Estimator} API instead of @{tf.contrib.learn.Estimator}. +""" from __future__ import absolute_import from __future__ import division @@ -29,12 +35,17 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.summary import summary from tensorflow.python.ops.control_flow_ops import with_dependencies from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training.session_run_hook import SessionRunArgs +from tensorflow.python.util.deprecation import deprecated + +_USE_TF_CONTRIB_FACTORIZATION = ( + 'Please use tf.contrib.factorization.KMeansClustering instead of' + ' tf.contrib.learn.KMeansClustering. It has a similar interface, but uses' + ' the tf.estimator.Estimator API instead of tf.contrib.learn.Estimator.') class _LossRelativeChangeHook(session_run_hook.SessionRunHook): @@ -106,7 +117,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): """Model function for KMeansClustering estimator.""" assert labels is None, labels (all_scores, model_predictions, losses, - is_initialized, _, init_op, training_op) = clustering_ops.KMeans( + is_initialized, init_op, training_op) = clustering_ops.KMeans( _parse_tensor_or_dict(features), params.get('num_clusters'), initial_clusters=params.get('training_initial_clusters'), @@ -153,6 +164,7 @@ class KMeansClustering(estimator.Estimator): ALL_SCORES = 'all_scores' LOSS_OP_NAME = 'kmeans_loss' + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def __init__(self, num_clusters, model_dir=None, @@ -204,6 +216,7 @@ class KMeansClustering(estimator.Estimator): model_dir=model_dir, config=config) + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def predict_cluster_idx(self, input_fn=None): """Yields predicted cluster indices.""" key = KMeansClustering.CLUSTER_IDX @@ -212,6 +225,7 @@ class KMeansClustering(estimator.Estimator): for result in results: yield result[key] + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def score(self, input_fn=None, steps=None): """Predict total sum of distances to nearest clusters. @@ -229,6 +243,7 @@ class KMeansClustering(estimator.Estimator): self.evaluate( input_fn=input_fn, steps=steps)[KMeansClustering.SCORES]) + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def transform(self, input_fn=None, as_iterable=False): """Transforms each element to distances to cluster centers. @@ -255,6 +270,7 @@ class KMeansClustering(estimator.Estimator): else: return results + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def clusters(self): """Returns cluster centers.""" return super(KMeansClustering, self).get_variable_value(self.CLUSTERS) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 9b55826e627c5198ba7f88505afb866a0f308553..307db76afe20a7743df16d169270a6f319497eb6 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -149,16 +149,16 @@ class Experiment(object): Args: estimator: Object implementing Estimator interface, which could be a - combination of ${tf.contrib.learn.Trainable} and - ${tf.contrib.learn.Evaluable} (deprecated), or - ${tf.estimator.`Estimator}. + combination of @{tf.contrib.learn.Trainable} and + @{tf.contrib.learn.Evaluable} (deprecated), or + @{tf.estimator.Estimator}. train_input_fn: function, returns features and labels for training. eval_input_fn: function, returns features and labels for evaluation. If `eval_steps` is `None`, this should be configured only to produce for a finite number of batches (generally, 1 epoch over the evaluation data). eval_metrics: `dict` of string, metric function. If `None`, default set is used. This should be `None` if the `estimator` is - ${tf.estimator.Estimator}. If metrics are provided they will be + @{tf.estimator.Estimator}. If metrics are provided they will be *appended* to the default set. train_steps: Perform this many steps of training. `None`, the default, means train forever. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py index eaf6ae4ed72148436c3d1aa3838b516c6025b0aa..82848be7df653dd60219317d28f233767746f544 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py @@ -42,16 +42,6 @@ class DataFeederTest(test.TestCase): with self.assertRaisesRegexp(TypeError, 'annot convert'): data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) - def test_input_uint32(self): - data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32) - self._assert_raises(data) - self._assert_raises(self._wrap_dict(data)) - - def test_input_uint64(self): - data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64) - self._assert_raises(data) - self._assert_raises(self._wrap_dict(data)) - def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data): feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) if isinstance(input_data, dict): @@ -87,6 +77,16 @@ class DataFeederTest(test.TestCase): self._assert_dtype(np.int64, dtypes.int64, data) self._assert_dtype(np.int64, dtypes.int64, self._wrap_dict(data)) + def test_input_uint32(self): + data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32) + self._assert_dtype(np.uint32, dtypes.uint32, data) + self._assert_dtype(np.uint32, dtypes.uint32, self._wrap_dict(data)) + + def test_input_uint64(self): + data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64) + self._assert_dtype(np.uint64, dtypes.uint64, data) + self._assert_dtype(np.uint64, dtypes.uint64, self._wrap_dict(data)) + def test_input_uint8(self): data = np.matrix([[1, 2], [3, 4]], dtype=np.uint8) self._assert_dtype(np.uint8, dtypes.uint8, data) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index bdb88b89bb3dba95a229724994874b0a26b1fc3f..4b34fc62849766370979bb2002d42ee03ea7161a 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -442,7 +442,8 @@ def read_keyed_batch_features(file_pattern, feature_queue_capacity=100, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -482,6 +483,8 @@ def read_keyed_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: Returns tuple of: @@ -493,6 +496,7 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: + if read_batch_size is None: read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, @@ -501,7 +505,7 @@ def read_keyed_batch_features(file_pattern, num_epochs=num_epochs, queue_capacity=queue_capacity, num_threads=reader_num_threads, - read_batch_size=batch_size, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=scope) # Parse the example. @@ -727,7 +731,8 @@ def read_batch_features(file_pattern, reader_num_threads=1, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -768,6 +773,8 @@ def read_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: A dict of `Tensor` or `SparseTensor` objects for each in `features`. @@ -786,6 +793,7 @@ def read_batch_features(file_pattern, reader_num_threads=reader_num_threads, feature_queue_capacity=feature_queue_capacity, num_enqueue_threads=num_enqueue_threads, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=name) return features diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 9f9740ec492e8b71191aff17f70a007409525ccd..2af723a0d64822e81fa0fbeb106ab812de6ab4e8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -165,7 +165,7 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None, must be None. 2) It accepts two arguments `run_config` and `hparams`, which should be used to create the `Estimator` (`run_config` passed as `config` to its - constructor; `hparams` used as the hyper-paremeters of the model). + constructor; `hparams` used as the hyper-parameters of the model). It must return an `Experiment`. For this case, `output_dir` must be None. output_dir: Base output directory [Deprecated]. schedule: The name of the method in the `Experiment` to run. diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index ee8856ac34219971b886a9f2c4b4e8f2ae639697..49413092a6bae547ddd2cad272b1abb3af1de046 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -50,6 +50,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.training import saver from tensorflow.python.util import compat @@ -108,7 +109,11 @@ def build_standardized_signature_def(input_tensors, output_tensors, classes = _get_classification_classes(output_tensors) scores = _get_classification_scores(output_tensors) if classes is None and scores is None: - (_, classes), = output_tensors.items() + items = list(output_tensors.items()) + if items[0][1].dtype == dtypes.string: + (_, classes), = items + else: + (_, scores), = items return signature_def_utils.classification_signature_def( examples, classes, scores) elif _is_regression_problem(problem_type, input_tensors, output_tensors): @@ -616,7 +621,13 @@ def make_best_model_export_strategy(serving_input_fn, Returns: The string path to the exported directory. """ - + if not checkpoint_path: + # TODO(b/67425018): switch to + # checkpoint_path = estimator.latest_checkpoint() + # as soon as contrib is cleaned up and we can thus be sure that + # estimator is a tf.estimator.Estimator and not a + # tf.contrib.learn.Estimator + checkpoint_path = saver.latest_checkpoint(estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 810a3d34eee0a886fcf49ca3209547c9307a6e67..734bac17dc82a61fd4c85b6277625d4a35961958 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -10,152 +10,7 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") - -cuda_py_tests( - name = "linear_operator_test", - size = "small", - srcs = ["python/kernel_tests/linear_operator_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_addition_test", - size = "small", - srcs = ["python/kernel_tests/linear_operator_addition_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_composition_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_composition_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], - tags = ["noasan"], # times out b/63678675 -) - -cuda_py_tests( - name = "linear_operator_diag_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_diag_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - ], -) - -cuda_py_tests( - name = "linear_operator_identity_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_identity_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - ], -) - -cuda_py_tests( - name = "linear_operator_full_matrix_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_full_matrix_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_tril_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_tril_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_udvh_update_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_udvh_update_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], - shard_count = 5, -) - -cuda_py_tests( - name = "linear_operator_util_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_util_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "linalg_py", @@ -176,11 +31,29 @@ py_library( "//tensorflow/python:random_seed", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/ops/linalg", "//third_party/py/numpy", "@six_archive//:six", ], ) +cuda_py_test( + name = "linear_operator_addition_test", + size = "small", + srcs = ["python/kernel_tests/linear_operator_addition_test.py"], + additional_deps = [ + ":linalg_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 44421a6b7de0344a9a4a172ddc0900a44eb74450..4720692c3384ba1bede1f486c1b1e0e69d10a63a 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -21,8 +21,8 @@ See the @{$python/contrib.linalg} guide. @@LinearOperatorIdentity @@LinearOperatorScaledIdentity @@LinearOperatorFullMatrix -@@LinearOperatorTriL -@@LinearOperatorUDVHUpdate +@@LinearOperatorLowerTriangular +@@LinearOperatorLowRankUpdate @@LinearOperatorComposition @@add_operators @@ -33,14 +33,14 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member -from tensorflow.contrib.linalg.python.ops.linear_operator import * from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_composition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_diag import * -from tensorflow.contrib.linalg.python.ops.linear_operator_full_matrix import * -from tensorflow.contrib.linalg.python.ops.linear_operator_identity import * -from tensorflow.contrib.linalg.python.ops.linear_operator_tril import * -from tensorflow.contrib.linalg.python.ops.linear_operator_udvh_update import * +from tensorflow.python.ops.linalg.linear_operator import * +from tensorflow.python.ops.linalg.linear_operator_composition import * +from tensorflow.python.ops.linalg.linear_operator_diag import * +from tensorflow.python.ops.linalg.linear_operator_full_matrix import * +from tensorflow.python.ops.linalg.linear_operator_identity import * +from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * +from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py index 474648475579d3f840d18ba3dd3291b90755a93a..6a72df6dfd8d8c35211bab42b240b83d77160a02 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py @@ -19,10 +19,10 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg as linalg_lib from tensorflow.contrib.linalg.python.ops import linear_operator_addition from tensorflow.python.framework import random_seed from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib from tensorflow.python.platform import test linalg = linalg_lib @@ -114,7 +114,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): def test_diag_tril_diag(self): op1 = linalg.LinearOperatorDiag( [1., 1.], is_non_singular=True, name="diag_a") - op2 = linalg.LinearOperatorTriL( + op2 = linalg.LinearOperatorLowerTriangular( [[2., 0.], [0., 2.]], is_self_adjoint=True, is_non_singular=True, @@ -125,7 +125,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): op_sum = add_operators([op1, op2, op3]) self.assertEqual(1, len(op_sum)) op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorTriL)) + self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular)) self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) # The diag operators will be self-adjoint (because real and diagonal). @@ -140,7 +140,8 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): op0 = linalg.LinearOperatorFullMatrix( [[-1., -1.], [-1., -1.]], name="matrix") op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a") - op2 = linalg.LinearOperatorTriL([[2., 0.], [1.5, 2.]], name="tril") + op2 = linalg.LinearOperatorLowerTriangular( + [[2., 0.], [1.5, 2.]], name="tril") op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b") with self.test_session(): op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator") @@ -189,7 +190,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): def test_tier_1_additions_done_by_tier_1(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], [linear_operator_addition._AddAndReturnTriL()], @@ -199,12 +200,12 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): # _BadAdder) was never reached. op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) def test_tier_1_additions_done_by_tier_1_with_order_flipped(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnTriL()], [linear_operator_addition._AddAndReturnDiag()], @@ -216,12 +217,12 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): # Tier 2 was never used (therefore, _BadAdder didn't raise). op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) def test_cannot_add_everything_so_return_more_than_one_operator(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([2.]) - tril5 = linalg.LinearOperatorTriL([[5.]]) + tril5 = linalg.LinearOperatorLowerTriangular([[5.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], ] @@ -237,7 +238,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): if isinstance(op, linalg.LinearOperatorDiag): found_diag = True self.assertAllClose([[3.]], op.to_dense().eval()) - if isinstance(op, linalg.LinearOperatorTriL): + if isinstance(op, linalg.LinearOperatorLowerTriangular): found_tril = True self.assertAllClose([[5.]], op.to_dense().eval()) self.assertTrue(found_diag and found_tril) @@ -245,7 +246,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): def test_intermediate_tier_is_not_skipped(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], [_BadAdder()], @@ -369,14 +370,14 @@ class AddAndReturnTriLTest(test.TestCase): def test_diag_plus_tril(self): diag = linalg.LinearOperatorDiag([1., 2.]) - tril = linalg.LinearOperatorTriL([[10., 0.], [30., 0.]]) + tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]]) hints = linear_operator_addition._Hints( is_positive_definite=True, is_non_singular=True) self.assertTrue(self._adder.can_add(diag, diag)) self.assertTrue(self._adder.can_add(diag, tril)) operator = self._adder.add(diag, tril, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular)) with self.test_session(): self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval()) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py index 16c4c6e6d67f17d1674b8d1d39f006bc688bc6ce..86130a2c077ce14a7539b281ec809029bc05e071 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py @@ -22,14 +22,14 @@ import abc import six -from tensorflow.contrib.linalg.python.ops import linear_operator -from tensorflow.contrib.linalg.python.ops import linear_operator_diag -from tensorflow.contrib.linalg.python.ops import linear_operator_full_matrix -from tensorflow.contrib.linalg.python.ops import linear_operator_identity -from tensorflow.contrib.linalg.python.ops import linear_operator_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_diag +from tensorflow.python.ops.linalg import linear_operator_full_matrix +from tensorflow.python.ops.linalg import linear_operator_identity +from tensorflow.python.ops.linalg import linear_operator_lower_triangular __all__ = [] @@ -347,7 +347,7 @@ class _AddAndReturnTriL(_Adder): else: op_add_to_tensor, op_other = op2, op1 - return linear_operator_tril.LinearOperatorTriL( + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), is_non_singular=hints.is_non_singular, is_self_adjoint=hints.is_self_adjoint, @@ -397,7 +397,8 @@ def _type(operator): """Returns the type name constant (e.g. _TRIL) for operator.""" if isinstance(operator, linear_operator_diag.LinearOperatorDiag): return _DIAG - if isinstance(operator, linear_operator_tril.LinearOperatorTriL): + if isinstance(operator, + linear_operator_lower_triangular.LinearOperatorLowerTriangular): return _TRIL if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): return _MATRIX diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index f75b0aa1b3e6606b0c92ae94b15b12781fe8b777..33fbbe12d3926606c468d13bef2842b81a857edb 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -15,6 +15,7 @@ py_library( "__init__.py", "python/losses/__init__.py", "python/losses/loss_ops.py", + "python/metric_learning/metric_loss_ops.py", ], srcs_version = "PY2AND3", deps = [ @@ -50,6 +51,49 @@ py_test( ], ) +py_library( + name = "metric_learning_py", + srcs = [ + "python/metric_learning/__init__.py", + "python/metric_learning/metric_loss_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:util", + ], +) + +py_test( + name = "metric_loss_ops_test", + srcs = [ + "python/metric_learning/metric_loss_ops_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":metric_learning_py", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py index 790bf61367d85b79bae4b153328b229b10721b38..db58647d48f0f6f093ef4b71d1e8a7b79e611184 100644 --- a/tensorflow/contrib/losses/__init__.py +++ b/tensorflow/contrib/losses/__init__.py @@ -22,10 +22,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.losses.python import metric_learning # pylint: disable=wildcard-import from tensorflow.contrib.losses.python.losses import * # pylint: enable=wildcard-import - from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ @@ -43,5 +43,6 @@ _allowed_symbols = [ 'sigmoid_cross_entropy', 'softmax_cross_entropy', 'sparse_softmax_cross_entropy', + 'metric_learning' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 1d2477b8b794240bd348cec7f626be794181ffb4..7c523ad49265aaf32c8d5a8ae04d3e93262a1b55 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.deprecation import deprecated_args __all__ = ["absolute_difference", "add_loss", @@ -623,8 +624,9 @@ def mean_pairwise_squared_error( @deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.") +@deprecated_args(None, "dim is deprecated, use axis instead", "dim") def cosine_distance( - predictions, labels=None, dim=None, weights=1.0, scope=None): + predictions, labels=None, axis=None, weights=1.0, scope=None, dim=None): """Adds a cosine-distance loss to the training procedure. Note that the function assumes that `predictions` and `labels` are already @@ -633,10 +635,11 @@ def cosine_distance( Args: predictions: An arbitrary matrix. labels: A `Tensor` whose shape matches 'predictions' - dim: The dimension along which the cosine distance is computed. + axis: The dimension along which the cosine distance is computed. weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. scope: The scope for the operations performed in computing the loss. + dim: The old (deprecated) name for `axis`. Returns: A scalar `Tensor` representing the loss value. @@ -645,8 +648,12 @@ def cosine_distance( ValueError: If `predictions` shape doesn't match `labels` shape, or `weights` is `None`. """ - if dim is None: - raise ValueError("`dim` cannot be None.") + if dim is not None: + if axis is not None: + raise ValueError("Cannot specify both 'axis' and 'dim'") + axis = dim + if axis is None and dim is None: + raise ValueError("You must specify 'axis'.") with ops.name_scope(scope, "cosine_distance_loss", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -655,5 +662,5 @@ def cosine_distance( labels = math_ops.to_float(labels) radial_diffs = math_ops.multiply(predictions, labels) - losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,]) + losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[axis,]) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/losses/python/metric_learning/__init__.py b/tensorflow/contrib/losses/python/metric_learning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e551d6acafb5c565965503075e8416e01c20a71 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for building neural network losses. + +See @{$python/contrib.losses}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.losses.python.metric_learning.metric_loss_ops import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'contrastive_loss', + 'cluster_loss', + 'lifted_struct_loss', + 'npairs_loss', + 'npairs_loss_multilabel', + 'triplet_semihard_loss', +] +remove_undocumented(__name__, _allowed_symbols) + + diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a57ba51bcf0a292490dfaa9e556f6e5811ed66 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -0,0 +1,1031 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements various metric learning losses.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.summary import summary +try: + # pylint: disable=g-import-not-at-top + from sklearn import metrics + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + + +def pairwise_distance(feature, squared=False): + """Computes the pairwise distance matrix with numerical stability. + + output[i, j] = || feature[i, :] - feature[j, :] ||_2 + + Args: + feature: 2-D Tensor of size [number of data, feature dimension]. + squared: Boolean, whether or not to square the pairwise distances. + + Returns: + pairwise_distances: 2-D Tensor of size [number of data, number of data]. + """ + pairwise_distances_squared = math_ops.add( + math_ops.reduce_sum( + math_ops.square(feature), + axis=[1], + keep_dims=True), + math_ops.reduce_sum( + math_ops.square( + array_ops.transpose(feature)), + axis=[0], + keep_dims=True)) - 2.0 * math_ops.matmul( + feature, array_ops.transpose(feature)) + + # Deal with numerical inaccuracies. Set small negatives to zero. + pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0) + # Get the mask where the zero distances are at. + error_mask = math_ops.less_equal(pairwise_distances_squared, 0.0) + + # Optionally take the sqrt. + if squared: + pairwise_distances = pairwise_distances_squared + else: + pairwise_distances = math_ops.sqrt( + pairwise_distances_squared + math_ops.to_float(error_mask) * 1e-16) + + # Undo conditionally adding 1e-16. + pairwise_distances = math_ops.multiply( + pairwise_distances, math_ops.to_float(math_ops.logical_not(error_mask))) + + num_data = array_ops.shape(feature)[0] + # Explicitly set diagonals to zero. + mask_offdiagonals = array_ops.ones_like(pairwise_distances) - array_ops.diag( + array_ops.ones([num_data])) + pairwise_distances = math_ops.multiply(pairwise_distances, mask_offdiagonals) + return pairwise_distances + + +def contrastive_loss(labels, embeddings_anchor, embeddings_positive, + margin=1.0): + """Computes the contrastive loss. + + This loss encourages the embedding to be close to each other for + the samples of the same label and the embedding to be far apart at least + by the margin constant for the samples of different labels. + See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + binary labels indicating positive vs negative pair. + embeddings_anchor: 2-D float `Tensor` of embedding vectors for the anchor + images. Embeddings should be l2 normalized. + embeddings_positive: 2-D float `Tensor` of embedding vectors for the + positive images. Embeddings should be l2 normalized. + margin: margin term in the loss definition. + + Returns: + contrastive_loss: tf.float32 scalar. + """ + # Get per pair distances + distances = math_ops.sqrt( + math_ops.reduce_sum( + math_ops.square(embeddings_anchor - embeddings_positive), 1)) + + # Add contrastive loss for the siamese network. + # label here is {0,1} for neg, pos. + return math_ops.reduce_mean( + math_ops.to_float(labels) * math_ops.square(distances) + + (1. - math_ops.to_float(labels)) * + math_ops.square(math_ops.maximum(margin - distances, 0.)), + name='contrastive_loss') + + +def masked_maximum(data, mask, dim=1): + """Computes the axis wise maximum over chosen elements. + + Args: + data: 2-D float `Tensor` of size [n, m]. + mask: 2-D Boolean `Tensor` of size [n, m]. + dim: The dimension over which to compute the maximum. + + Returns: + masked_maximums: N-D `Tensor`. + The maximized dimension is of size 1 after the operation. + """ + axis_minimums = math_ops.reduce_min(data, dim, keep_dims=True) + masked_maximums = math_ops.reduce_max( + math_ops.multiply( + data - axis_minimums, mask), dim, keep_dims=True) + axis_minimums + return masked_maximums + + +def masked_minimum(data, mask, dim=1): + """Computes the axis wise minimum over chosen elements. + + Args: + data: 2-D float `Tensor` of size [n, m]. + mask: 2-D Boolean `Tensor` of size [n, m]. + dim: The dimension over which to compute the minimum. + + Returns: + masked_minimums: N-D `Tensor`. + The minimized dimension is of size 1 after the operation. + """ + axis_maximums = math_ops.reduce_max(data, dim, keep_dims=True) + masked_minimums = math_ops.reduce_min( + math_ops.multiply( + data - axis_maximums, mask), dim, keep_dims=True) + axis_maximums + return masked_minimums + + +def triplet_semihard_loss(labels, embeddings, margin=1.0): + """Computes the triplet loss with semi-hard negative mining. + + The loss encourages the positive distances (between a pair of embeddings with + the same labels) to be smaller than the minimum negative distance among + which are at least greater than the positive distance plus the margin constant + (called semi-hard negative) in the mini-batch. If no such negative exists, + uses the largest negative distance instead. + See: https://arxiv.org/abs/1503.03832. + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + multiclass integer labels. + embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should + be l2 normalized. + margin: Float, margin term in the loss definition. + + Returns: + triplet_loss: tf.float32 scalar. + """ + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + # Build pairwise squared distance matrix. + pdist_matrix = pairwise_distance(embeddings, squared=True) + # Build pairwise binary adjacency matrix. + adjacency = math_ops.equal(labels, array_ops.transpose(labels)) + # Invert so we can select negatives only. + adjacency_not = math_ops.logical_not(adjacency) + + batch_size = array_ops.size(labels) + + # Compute the mask. + pdist_matrix_tile = array_ops.tile(pdist_matrix, [batch_size, 1]) + mask = math_ops.logical_and( + array_ops.tile(adjacency_not, [batch_size, 1]), + math_ops.greater( + pdist_matrix_tile, array_ops.reshape( + array_ops.transpose(pdist_matrix), [-1, 1]))) + mask_final = array_ops.reshape( + math_ops.greater( + math_ops.reduce_sum( + math_ops.cast( + mask, dtype=dtypes.float32), 1, keep_dims=True), + 0.0), [batch_size, batch_size]) + mask_final = array_ops.transpose(mask_final) + + adjacency_not = math_ops.cast(adjacency_not, dtype=dtypes.float32) + mask = math_ops.cast(mask, dtype=dtypes.float32) + + # negatives_outside: smallest D_an where D_an > D_ap. + negatives_outside = array_ops.reshape( + masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size]) + negatives_outside = array_ops.transpose(negatives_outside) + + # negatives_inside: largest D_an. + negatives_inside = array_ops.tile( + masked_maximum(pdist_matrix, adjacency_not), [1, batch_size]) + semi_hard_negatives = array_ops.where( + mask_final, negatives_outside, negatives_inside) + + loss_mat = math_ops.add(margin, pdist_matrix - semi_hard_negatives) + + mask_positives = math_ops.cast( + adjacency, dtype=dtypes.float32) - array_ops.diag( + array_ops.ones([batch_size])) + + # In lifted-struct, the authors multiply 0.5 for upper triangular + # in semihard, they take all positive pairs except the diagonal. + num_positives = math_ops.reduce_sum(mask_positives) + + triplet_loss = math_ops.truediv( + math_ops.reduce_sum( + math_ops.maximum( + math_ops.multiply(loss_mat, mask_positives), 0.0)), + num_positives, + name='triplet_semihard_loss') + + return triplet_loss + + +# pylint: disable=line-too-long +def npairs_loss(labels, embeddings_anchor, embeddings_positive, + reg_lambda=0.002, print_losses=False): + """Computes the npairs loss. + + Npairs loss expects paired data where a pair is composed of samples from the + same labels and each pairs in the minibatch have different labels. The loss + has two components. The first component is the L2 regularizer on the + embedding vectors. The second component is the sum of cross entropy loss + which takes each row of the pair-wise similarity matrix as logits and + the remapped one-hot labels as labels. + + See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf + + Args: + labels: 1-D tf.int32 `Tensor` of shape [batch_size/2]. + embeddings_anchor: 2-D Tensor of shape [batch_size/2, embedding_dim] for the + embedding vectors for the anchor images. Embeddings should not be + l2 normalized. + embeddings_positive: 2-D Tensor of shape [batch_size/2, embedding_dim] for the + embedding vectors for the positive images. Embeddings should not be + l2 normalized. + reg_lambda: Float. L2 regularization term on the embedding vectors. + print_losses: Boolean. Option to print the xent and l2loss. + + Returns: + npairs_loss: tf.float32 scalar. + """ + # pylint: enable=line-too-long + # Add the regularizer on the embedding. + reg_anchor = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_anchor), 1)) + reg_positive = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_positive), 1)) + l2loss = math_ops.multiply( + 0.25 * reg_lambda, reg_anchor + reg_positive, name='l2loss') + + # Get per pair similarities. + similarity_matrix = math_ops.matmul( + embeddings_anchor, embeddings_positive, transpose_a=False, + transpose_b=True) + + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + labels_remapped = math_ops.to_float( + math_ops.equal(labels, array_ops.transpose(labels))) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + + # Add the softmax loss. + xent_loss = nn.softmax_cross_entropy_with_logits( + logits=similarity_matrix, labels=labels_remapped) + xent_loss = math_ops.reduce_mean(xent_loss, name='xentropy') + + if print_losses: + xent_loss = logging_ops.Print( + xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss]) + + return l2loss + xent_loss + + +def _build_multilabel_adjacency(sparse_labels): + """Builds multilabel adjacency matrix. + + As of March 14th, 2017, there's no op for the dot product between + two sparse tensors in TF. However, there is `sparse_minimum` op which is + equivalent to an AND op between two sparse boolean tensors. + This computes the dot product between two sparse boolean inputs. + + Args: + sparse_labels: List of 1-D boolean sparse tensors. + + Returns: + adjacency_matrix: 2-D dense `Tensor`. + """ + num_pairs = len(sparse_labels) + adjacency_matrix = array_ops.zeros([num_pairs, num_pairs]) + for i in range(num_pairs): + for j in range(num_pairs): + sparse_dot_product = math_ops.to_float( + sparse_ops.sparse_reduce_sum(sparse_ops.sparse_minimum( + sparse_labels[i], sparse_labels[j]))) + sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 0) + sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 1) + one_hot_matrix = array_ops.pad(sparse_dot_product, + [[i, num_pairs-i-1], + [j, num_pairs-j-1]], 'CONSTANT') + adjacency_matrix += one_hot_matrix + + return adjacency_matrix + + +def npairs_loss_multilabel(sparse_labels, embeddings_anchor, + embeddings_positive, reg_lambda=0.002, + print_losses=False): + r"""Computes the npairs loss with multilabel data. + + Npairs loss expects paired data where a pair is composed of samples from the + same labels and each pairs in the minibatch have different labels. The loss + has two components. The first component is the L2 regularizer on the + embedding vectors. The second component is the sum of cross entropy loss + which takes each row of the pair-wise similarity matrix as logits and + the remapped one-hot labels as labels. Here, the similarity is defined by the + dot product between two embedding vectors. S_{i,j} = f(x_i)^T f(x_j) + + To deal with multilabel inputs, we use the count of label intersection + i.e. L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) | + Then we normalize each rows of the count based label matrix so that each row + sums to one. + + Args: + sparse_labels: List of 1-D Boolean `SparseTensor` of dense_shape + [batch_size/2, num_classes] labels for the anchor-pos pairs. + embeddings_anchor: 2-D `Tensor` of shape [batch_size/2, embedding_dim] for + the embedding vectors for the anchor images. Embeddings should not be + l2 normalized. + embeddings_positive: 2-D `Tensor` of shape [batch_size/2, embedding_dim] for + the embedding vectors for the positive images. Embeddings should not be + l2 normalized. + reg_lambda: Float. L2 regularization term on the embedding vectors. + print_losses: Boolean. Option to print the xent and l2loss. + + Returns: + npairs_loss: tf.float32 scalar. + Raises: + TypeError: When the specified sparse_labels is not a `SparseTensor`. + """ + if False in [isinstance( + l, sparse_tensor.SparseTensor) for l in sparse_labels]: + raise TypeError( + 'sparse_labels must be a list of SparseTensors, but got %s' % str( + sparse_labels)) + + with ops.name_scope('NpairsLossMultiLabel'): + # Add the regularizer on the embedding. + reg_anchor = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_anchor), 1)) + reg_positive = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_positive), 1)) + l2loss = math_ops.multiply(0.25 * reg_lambda, + reg_anchor + reg_positive, name='l2loss') + + # Get per pair similarities. + similarity_matrix = math_ops.matmul( + embeddings_anchor, embeddings_positive, transpose_a=False, + transpose_b=True) + + # TODO(coreylynch): need to check the sparse values + # TODO(coreylynch): are composed only of 0's and 1's. + + multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels) + labels_remapped = math_ops.to_float(multilabel_adjacency_matrix) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + + # Add the softmax loss. + xent_loss = nn.softmax_cross_entropy_with_logits( + logits=similarity_matrix, labels=labels_remapped) + xent_loss = math_ops.reduce_mean(xent_loss, name='xentropy') + + if print_losses: + xent_loss = logging_ops.Print( + xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss]) + + return l2loss + xent_loss + + +def lifted_struct_loss(labels, embeddings, margin=1.0): + """Computes the lifted structured loss. + + The loss encourages the positive distances (between a pair of embeddings + with the same labels) to be smaller than any negative distances (between a + pair of embeddings with different labels) in the mini-batch in a way + that is differentiable with respect to the embedding vectors. + See: https://arxiv.org/abs/1511.06452. + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + multiclass integer labels. + embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should not + be l2 normalized. + margin: Float, margin term in the loss definition. + + Returns: + lifted_loss: tf.float32 scalar. + """ + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + # Build pairwise squared distance matrix. + pairwise_distances = pairwise_distance(embeddings) + + # Build pairwise binary adjacency matrix. + adjacency = math_ops.equal(labels, array_ops.transpose(labels)) + # Invert so we can select negatives only. + adjacency_not = math_ops.logical_not(adjacency) + + batch_size = array_ops.size(labels) + + diff = margin - pairwise_distances + mask = math_ops.cast(adjacency_not, dtype=dtypes.float32) + # Safe maximum: Temporarily shift negative distances + # above zero before taking max. + # this is to take the max only among negatives. + row_minimums = math_ops.reduce_min(diff, 1, keep_dims=True) + row_negative_maximums = math_ops.reduce_max( + math_ops.multiply( + diff - row_minimums, mask), 1, keep_dims=True) + row_minimums + + # Compute the loss. + # Keep track of matrix of maximums where M_ij = max(m_i, m_j) + # where m_i is the max of alpha - negative D_i's. + # This matches the Caffe loss layer implementation at: + # https://github.com/rksltnl/Caffe-Deep-Metric-Learning-CVPR16/blob/0efd7544a9846f58df923c8b992198ba5c355454/src/caffe/layers/lifted_struct_similarity_softmax_layer.cpp # pylint: disable=line-too-long + + max_elements = math_ops.maximum( + row_negative_maximums, array_ops.transpose(row_negative_maximums)) + diff_tiled = array_ops.tile(diff, [batch_size, 1]) + mask_tiled = array_ops.tile(mask, [batch_size, 1]) + max_elements_vect = array_ops.reshape( + array_ops.transpose(max_elements), [-1, 1]) + + loss_exp_left = array_ops.reshape( + math_ops.reduce_sum(math_ops.multiply( + math_ops.exp( + diff_tiled - max_elements_vect), + mask_tiled), 1, keep_dims=True), [batch_size, batch_size]) + + loss_mat = max_elements + math_ops.log( + loss_exp_left + array_ops.transpose(loss_exp_left)) + # Add the positive distance. + loss_mat += pairwise_distances + + mask_positives = math_ops.cast( + adjacency, dtype=dtypes.float32) - array_ops.diag( + array_ops.ones([batch_size])) + + # *0.5 for upper triangular, and another *0.5 for 1/2 factor for loss^2. + num_positives = math_ops.reduce_sum(mask_positives) / 2.0 + + lifted_loss = math_ops.truediv( + 0.25 * math_ops.reduce_sum( + math_ops.square( + math_ops.maximum( + math_ops.multiply(loss_mat, mask_positives), 0.0))), + num_positives, + name='liftedstruct_loss') + return lifted_loss + + +def update_1d_tensor(y, index, value): + """Updates 1d tensor y so that y[index] = value. + + Args: + y: 1-D Tensor. + index: index of y to modify. + value: new value to write at y[index]. + + Returns: + y_mod: 1-D Tensor. Tensor y after the update. + """ + value = array_ops.squeeze(value) + # modify the 1D tensor x at index with value. + # ex) chosen_ids = update_1D_tensor(chosen_ids, cluster_idx, best_medoid) + y_before = array_ops.slice(y, [0], [index]) + y_after = array_ops.slice(y, [index + 1], [-1]) + y_mod = array_ops.concat([y_before, [value], y_after], 0) + return y_mod + + +def get_cluster_assignment(pairwise_distances, centroid_ids): + """Assign data points to the neareset centroids. + + Tensorflow has numerical instability and doesn't always choose + the data point with theoretically zero distance as it's nearest neighbor. + Thus, for each centroid in centroid_ids, explicitly assign + the centroid itself as the nearest centroid. + This is done through the mask tensor and the constraint_vect tensor. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + centroid_ids: 1-D Tensor of centroid indices. + + Returns: + y_fixed: 1-D tensor of cluster assignment. + """ + predictions = math_ops.argmin( + array_ops.gather(pairwise_distances, centroid_ids), dimension=0) + batch_size = array_ops.shape(pairwise_distances)[0] + + # Deal with numerical instability + mask = math_ops.reduce_any(array_ops.one_hot( + centroid_ids, batch_size, True, False, axis=-1, dtype=dtypes.bool), + axis=0) + constraint_one_hot = math_ops.multiply( + array_ops.one_hot(centroid_ids, + batch_size, + array_ops.constant(1, dtype=dtypes.int64), + array_ops.constant(0, dtype=dtypes.int64), + axis=0, + dtype=dtypes.int64), + math_ops.to_int64(math_ops.range(array_ops.shape(centroid_ids)[0]))) + constraint_vect = math_ops.reduce_sum( + array_ops.transpose(constraint_one_hot), axis=0) + + y_fixed = array_ops.where(mask, constraint_vect, predictions) + return y_fixed + + +def compute_facility_energy(pairwise_distances, centroid_ids): + """Compute the average travel distance to the assigned centroid. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + centroid_ids: 1-D Tensor of indices. + + Returns: + facility_energy: dtypes.float32 scalar. + """ + return -1.0 * math_ops.reduce_sum( + math_ops.reduce_min( + array_ops.gather(pairwise_distances, centroid_ids), axis=0)) + + +def compute_clustering_score(labels, predictions, margin_type): + """Computes the clustering score via sklearn.metrics functions. + + There are various ways to compute the clustering score. Intuitively, + we want to measure the agreement of two clustering assignments (labels vs + predictions) ignoring the permutations and output a score from zero to one. + (where the values close to one indicate significant agreement). + This code supports following scoring functions: + nmi: normalized mutual information + ami: adjusted mutual information + ari: adjusted random index + vmeasure: v-measure + const: indicator checking whether the two clusterings are the same. + See http://scikit-learn.org/stable/modules/classes.html#clustering-metrics + for the detailed descriptions. + Args: + labels: 1-D Tensor. ground truth cluster assignment. + predictions: 1-D Tensor. predicted cluster assignment. + margin_type: Type of structured margin to use. Default is nmi. + Returns: + clustering_score: dtypes.float32 scalar. + The possible valid values are from zero to one. + Zero means the worst clustering and one means the perfect clustering. + Raises: + ValueError: margin_type is not recognized. + """ + margin_type_to_func = { + 'nmi': _compute_nmi_score, + 'ami': _compute_ami_score, + 'ari': _compute_ari_score, + 'vmeasure': _compute_vmeasure_score, + 'const': _compute_zeroone_score + } + + if margin_type not in margin_type_to_func: + raise ValueError('Unrecognized margin_type: %s' % margin_type) + clustering_score_fn = margin_type_to_func[margin_type] + return array_ops.squeeze(clustering_score_fn(labels, predictions)) + + +def _compute_nmi_score(labels, predictions): + return math_ops.to_float( + script_ops.py_func( + metrics.normalized_mutual_info_score, [labels, predictions], + [dtypes.float64], + name='nmi')) + + +def _compute_ami_score(labels, predictions): + ami_score = math_ops.to_float( + script_ops.py_func( + metrics.adjusted_mutual_info_score, [labels, predictions], + [dtypes.float64], + name='ami')) + return math_ops.maximum(0.0, ami_score) + + +def _compute_ari_score(labels, predictions): + ari_score = math_ops.to_float( + script_ops.py_func( + metrics.adjusted_rand_score, [labels, predictions], [dtypes.float64], + name='ari')) + # ari score can go below 0 + # http://scikit-learn.org/stable/modules/clustering.html#adjusted-rand-score + return math_ops.maximum(0.0, ari_score) + + +def _compute_vmeasure_score(labels, predictions): + vmeasure_score = math_ops.to_float( + script_ops.py_func( + metrics.v_measure_score, [labels, predictions], [dtypes.float64], + name='vmeasure')) + return math_ops.maximum(0.0, vmeasure_score) + + +def _compute_zeroone_score(labels, predictions): + zeroone_score = math_ops.to_float( + math_ops.equal( + math_ops.reduce_sum( + math_ops.to_int32(math_ops.equal(labels, predictions))), + array_ops.shape(labels)[0])) + return zeroone_score + + +def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids, + candidate_ids, margin_multiplier, + margin_type): + """Find the next centroid that maximizes the loss augmented inference. + + This function is a subroutine called from compute_augmented_facility_locations + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of current centroid indices. + candidate_ids: 1-D Tensor of candidate indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + integer index. + """ + num_candidates = array_ops.shape(candidate_ids)[0] + + pairwise_distances_chosen = array_ops.gather(pairwise_distances, chosen_ids) + pairwise_distances_candidate = array_ops.gather( + pairwise_distances, candidate_ids) + pairwise_distances_chosen_tile = array_ops.tile( + pairwise_distances_chosen, [1, num_candidates]) + + candidate_scores = -1.0 * math_ops.reduce_sum( + array_ops.reshape( + math_ops.reduce_min( + array_ops.concat([ + pairwise_distances_chosen_tile, + array_ops.reshape(pairwise_distances_candidate, [1, -1]) + ], 0), + axis=0, + keep_dims=True), [num_candidates, -1]), + axis=1) + + nmi_scores = array_ops.zeros([num_candidates]) + iteration = array_ops.constant(0) + + def func_cond(iteration, nmi_scores): + del nmi_scores # Unused in func_cond() + return iteration < num_candidates + + def func_body(iteration, nmi_scores): + predictions = get_cluster_assignment( + pairwise_distances, + array_ops.concat([chosen_ids, [candidate_ids[iteration]]], 0)) + nmi_score_i = compute_clustering_score(labels, predictions, margin_type) + pad_before = array_ops.zeros([iteration]) + pad_after = array_ops.zeros([num_candidates - 1 - iteration]) + # return 1 - NMI score as the structured loss. + # because NMI is higher the better [0,1]. + return iteration + 1, nmi_scores + array_ops.concat( + [pad_before, [1.0 - nmi_score_i], pad_after], 0) + + _, nmi_scores = control_flow_ops.while_loop( + func_cond, func_body, [iteration, nmi_scores]) + + candidate_scores = math_ops.add( + candidate_scores, margin_multiplier * nmi_scores) + + argmax_index = math_ops.to_int32( + math_ops.argmax(candidate_scores, dimension=0)) + + return candidate_ids[argmax_index] + + +def compute_augmented_facility_locations(pairwise_distances, labels, all_ids, + margin_multiplier, margin_type): + """Computes the centroid locations. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + all_ids: 1-D Tensor of all data indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: 1-D Tensor of chosen centroid indices. + """ + + def func_cond_augmented(iteration, chosen_ids): + del chosen_ids # Unused argument in func_cond_augmented. + return iteration < num_classes + + def func_body_augmented(iteration, chosen_ids): + # find a new facility location to add + # based on the clustering score and the NMI score + candidate_ids = array_ops.setdiff1d(all_ids, chosen_ids)[0] + new_chosen_idx = _find_loss_augmented_facility_idx(pairwise_distances, + labels, chosen_ids, + candidate_ids, + margin_multiplier, + margin_type) + chosen_ids = array_ops.concat([chosen_ids, [new_chosen_idx]], 0) + return iteration + 1, chosen_ids + + num_classes = array_ops.size(array_ops.unique(labels)[0]) + chosen_ids = array_ops.constant(0, dtype=dtypes.int32, shape=[0]) + + # num_classes get determined at run time based on the sampled batch. + iteration = array_ops.constant(0) + + _, chosen_ids = control_flow_ops.while_loop( + func_cond_augmented, + func_body_augmented, [iteration, chosen_ids], + shape_invariants=[iteration.get_shape(), tensor_shape.TensorShape( + [None])]) + return chosen_ids + + +def update_medoid_per_cluster(pairwise_distances, pairwise_distances_subset, + labels, chosen_ids, cluster_member_ids, + cluster_idx, margin_multiplier, margin_type): + """Updates the cluster medoid per cluster. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + pairwise_distances_subset: 2-D Tensor of pairwise distances for one cluster. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of cluster centroid indices. + cluster_member_ids: 1-D Tensor of cluster member indices for one cluster. + cluster_idx: Index of this one cluster. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + + def func_cond(iteration, scores_margin): + del scores_margin # Unused variable scores_margin. + return iteration < num_candidates + + def func_body(iteration, scores_margin): + # swap the current medoid with the candidate cluster member + candidate_medoid = math_ops.to_int32(cluster_member_ids[iteration]) + tmp_chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, candidate_medoid) + predictions = get_cluster_assignment(pairwise_distances, tmp_chosen_ids) + metric_score = compute_clustering_score(labels, predictions, margin_type) + pad_before = array_ops.zeros([iteration]) + pad_after = array_ops.zeros([num_candidates - 1 - iteration]) + return iteration + 1, scores_margin + array_ops.concat( + [pad_before, [1.0 - metric_score], pad_after], 0) + + # pairwise_distances_subset is of size [p, 1, 1, p], + # the intermediate dummy dimensions at + # [1, 2] makes this code work in the edge case where p=1. + # this happens if the cluster size is one. + scores_fac = -1.0 * math_ops.reduce_sum( + array_ops.squeeze(pairwise_distances_subset, [1, 2]), axis=0) + + iteration = array_ops.constant(0) + num_candidates = array_ops.size(cluster_member_ids) + scores_margin = array_ops.zeros([num_candidates]) + + _, scores_margin = control_flow_ops.while_loop(func_cond, func_body, + [iteration, scores_margin]) + candidate_scores = math_ops.add(scores_fac, margin_multiplier * scores_margin) + + argmax_index = math_ops.to_int32( + math_ops.argmax(candidate_scores, dimension=0)) + + best_medoid = math_ops.to_int32(cluster_member_ids[argmax_index]) + chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, best_medoid) + return chosen_ids + + +def update_all_medoids(pairwise_distances, predictions, labels, chosen_ids, + margin_multiplier, margin_type): + """Updates all cluster medoids a cluster at a time. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + predictions: 1-D Tensor of predicted cluster assignment. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of cluster centroid indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + + def func_cond_augmented_pam(iteration, chosen_ids): + del chosen_ids # Unused argument. + return iteration < num_classes + + def func_body_augmented_pam(iteration, chosen_ids): + """Call the update_medoid_per_cluster subroutine.""" + mask = math_ops.equal( + math_ops.to_int64(predictions), math_ops.to_int64(iteration)) + this_cluster_ids = array_ops.where(mask) + + pairwise_distances_subset = array_ops.transpose( + array_ops.gather( + array_ops.transpose( + array_ops.gather(pairwise_distances, this_cluster_ids)), + this_cluster_ids)) + + chosen_ids = update_medoid_per_cluster(pairwise_distances, + pairwise_distances_subset, labels, + chosen_ids, this_cluster_ids, + iteration, margin_multiplier, + margin_type) + return iteration + 1, chosen_ids + + unique_class_ids = array_ops.unique(labels)[0] + num_classes = array_ops.size(unique_class_ids) + iteration = array_ops.constant(0) + + _, chosen_ids = control_flow_ops.while_loop( + func_cond_augmented_pam, func_body_augmented_pam, [iteration, chosen_ids]) + return chosen_ids + + +def compute_augmented_facility_locations_pam(pairwise_distances, + labels, + margin_multiplier, + margin_type, + chosen_ids, + pam_max_iter=5): + """Refine the cluster centroids with PAM local search. + + For fixed iterations, alternate between updating the cluster assignment + and updating cluster medoids. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + chosen_ids: 1-D Tensor of initial estimate of cluster centroids. + pam_max_iter: Number of refinement iterations. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + for _ in range(pam_max_iter): + # update the cluster assignment given the chosen_ids (S_pred) + predictions = get_cluster_assignment(pairwise_distances, chosen_ids) + + # update the medoids per each cluster + chosen_ids = update_all_medoids(pairwise_distances, predictions, labels, + chosen_ids, margin_multiplier, margin_type) + + return chosen_ids + + +def compute_gt_cluster_score(pairwise_distances, labels): + """Compute ground truth facility location score. + + Loop over each unique classes and compute average travel distances. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + + Returns: + gt_cluster_score: dtypes.float32 score. + """ + unique_class_ids = array_ops.unique(labels)[0] + num_classes = array_ops.size(unique_class_ids) + iteration = array_ops.constant(0) + gt_cluster_score = array_ops.constant(0.0, dtype=dtypes.float32) + + def func_cond(iteration, gt_cluster_score): + del gt_cluster_score # Unused argument. + return iteration < num_classes + + def func_body(iteration, gt_cluster_score): + """Per each cluster, compute the average travel distance.""" + mask = math_ops.equal(labels, unique_class_ids[iteration]) + this_cluster_ids = array_ops.where(mask) + pairwise_distances_subset = array_ops.transpose( + array_ops.gather( + array_ops.transpose( + array_ops.gather(pairwise_distances, this_cluster_ids)), + this_cluster_ids)) + this_cluster_score = -1.0 * math_ops.reduce_min( + math_ops.reduce_sum( + pairwise_distances_subset, axis=0)) + return iteration + 1, gt_cluster_score + this_cluster_score + + _, gt_cluster_score = control_flow_ops.while_loop( + func_cond, func_body, [iteration, gt_cluster_score]) + return gt_cluster_score + + +def cluster_loss(labels, + embeddings, + margin_multiplier, + enable_pam_finetuning=True, + margin_type='nmi', + print_losses=False): + """Computes the clustering loss. + + The following structured margins are supported: + nmi: normalized mutual information + ami: adjusted mutual information + ari: adjusted random index + vmeasure: v-measure + const: indicator checking whether the two clusterings are the same. + + Args: + labels: 2-D Tensor of labels of shape [batch size, 1] + embeddings: 2-D Tensor of embeddings of shape + [batch size, embedding dimension]. Embeddings should be l2 normalized. + margin_multiplier: float32 scalar. multiplier on the structured margin term + See section 3.2 of paper for discussion. + enable_pam_finetuning: Boolean, Whether to run local pam refinement. + See section 3.4 of paper for discussion. + margin_type: Type of structured margin to use. See section 3.2 of + paper for discussion. Can be 'nmi', 'ami', 'ari', 'vmeasure', 'const'. + print_losses: Boolean. Option to print the loss. + + Paper: https://arxiv.org/abs/1612.01213. + + Returns: + clustering_loss: A float32 scalar `Tensor`. + Raises: + ImportError: If sklearn dependency is not installed. + """ + if not HAS_SKLEARN: + raise ImportError('Cluster loss depends on sklearn.') + pairwise_distances = pairwise_distance(embeddings) + labels = array_ops.squeeze(labels) + all_ids = math_ops.range(array_ops.shape(embeddings)[0]) + + # Compute the loss augmented inference and get the cluster centroids. + chosen_ids = compute_augmented_facility_locations(pairwise_distances, labels, + all_ids, margin_multiplier, + margin_type) + # Given the predicted centroids, compute the clustering score. + score_pred = compute_facility_energy(pairwise_distances, chosen_ids) + + # Branch whether to use PAM finetuning. + if enable_pam_finetuning: + # Initialize with augmented facility solution. + chosen_ids = compute_augmented_facility_locations_pam(pairwise_distances, + labels, + margin_multiplier, + margin_type, + chosen_ids) + score_pred = compute_facility_energy(pairwise_distances, chosen_ids) + + # Given the predicted centroids, compute the cluster assignments. + predictions = get_cluster_assignment(pairwise_distances, chosen_ids) + + # Compute the clustering (i.e. NMI) score between the two assignments. + clustering_score_pred = compute_clustering_score(labels, predictions, + margin_type) + + # Compute the clustering score from labels. + score_gt = compute_gt_cluster_score(pairwise_distances, labels) + + # Compute the hinge loss. + clustering_loss = math_ops.maximum( + score_pred + margin_multiplier * (1.0 - clustering_score_pred) - score_gt, + 0.0, + name='clustering_loss') + clustering_loss.set_shape([]) + + if print_losses: + clustering_loss = logging_ops.Print( + clustering_loss, + ['clustering_loss: ', clustering_loss, array_ops.shape( + clustering_loss)]) + + # Clustering specific summary. + summary.scalar('losses/score_pred', score_pred) + summary.scalar('losses/' + margin_type, clustering_score_pred) + summary.scalar('losses/score_gt', score_gt) + + return clustering_loss diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec539ab42b4e0ba90a2a1f379a1d4d4b49d11f3 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py @@ -0,0 +1,562 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for triplet_semihard_loss.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.losses.python import metric_learning as metric_loss_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test +try: + # pylint: disable=g-import-not-at-top + from sklearn import datasets + from sklearn import metrics + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + + +def pairwise_distance_np(feature, squared=False): + """Computes the pairwise distance matrix in numpy. + + Args: + feature: 2-D numpy array of size [number of data, feature dimension] + squared: Boolean. If true, output is the pairwise squared euclidean + distance matrix; else, output is the pairwise euclidean distance matrix. + + Returns: + pairwise_distances: 2-D numpy array of size + [number of data, number of data]. + """ + triu = np.triu_indices(feature.shape[0], 1) + upper_tri_pdists = np.linalg.norm(feature[triu[1]] - feature[triu[0]], axis=1) + if squared: + upper_tri_pdists **= 2. + num_data = feature.shape[0] + pairwise_distances = np.zeros((num_data, num_data)) + pairwise_distances[np.triu_indices(num_data, 1)] = upper_tri_pdists + # Make symmetrical. + pairwise_distances = pairwise_distances + pairwise_distances.T - np.diag( + pairwise_distances.diagonal()) + return pairwise_distances + + +class ContrastiveLossTest(test.TestCase): + + def testContrastive(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + labels = np.random.randint(0, 2, size=(num_data,)).astype(np.float32) + + # Compute the loss in NP + dist = np.sqrt( + np.sum(np.square(embeddings_anchor - embeddings_positive), axis=1)) + loss_np = np.mean( + labels * np.square(dist) + + (1.0 - labels) * np.square(np.maximum(margin - dist, 0.0))) + # Compute the loss with TF + loss_tf = metric_loss_ops.contrastive_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class TripletSemiHardLossTest(test.TestCase): + + def testTripletSemiHard(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + # Compute the loss in NP. + adjacency = np.equal(labels_reshaped, labels_reshaped.T) + + pdist_matrix = pairwise_distance_np(embedding, squared=True) + loss_np = 0.0 + num_positives = 0.0 + for i in range(num_data): + for j in range(num_data): + if adjacency[i][j] > 0.0 and i != j: + num_positives += 1.0 + + pos_distance = pdist_matrix[i][j] + neg_distances = [] + + for k in range(num_data): + if adjacency[i][k] == 0: + neg_distances.append(pdist_matrix[i][k]) + + # Sort by distance. + neg_distances.sort() + chosen_neg_distance = neg_distances[0] + + for l in range(len(neg_distances)): + chosen_neg_distance = neg_distances[l] + if chosen_neg_distance > pos_distance: + break + + loss_np += np.maximum( + 0.0, margin - chosen_neg_distance + pos_distance) + + loss_np /= num_positives + + # Compute the loss in TF. + loss_tf = metric_loss_ops.triplet_semihard_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embedding), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class LiftedStructLossTest(test.TestCase): + + def testLiftedStruct(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + + # Compute the loss in NP + adjacency = np.equal(labels_reshaped, labels_reshaped.T) + pdist_matrix = pairwise_distance_np(embedding) + loss_np = 0.0 + num_constraints = 0.0 + for i in range(num_data): + for j in range(num_data): + if adjacency[i][j] > 0.0 and i != j: + d_pos = pdist_matrix[i][j] + negs = [] + for k in range(num_data): + if not adjacency[i][k]: + negs.append(margin - pdist_matrix[i][k]) + for l in range(num_data): + if not adjacency[j][l]: + negs.append(margin - pdist_matrix[j][l]) + + negs = np.array(negs) + max_elem = np.max(negs) + negs -= max_elem + negs = np.exp(negs) + soft_maximum = np.log(np.sum(negs)) + max_elem + + num_constraints += 1.0 + this_loss = max(soft_maximum + d_pos, 0) + loss_np += this_loss * this_loss + + loss_np = loss_np / num_constraints / 2.0 + + # Compute the loss in TF + loss_tf = metric_loss_ops.lifted_struct_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embedding), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +def convert_to_list_of_sparse_tensor(np_matrix): + list_of_sparse_tensors = [] + nrows, ncols = np_matrix.shape + for i in range(nrows): + sp_indices = [] + for j in range(ncols): + if np_matrix[i][j] == 1: + sp_indices.append([j]) + + num_non_zeros = len(sp_indices) + list_of_sparse_tensors.append(sparse_tensor.SparseTensor( + indices=np.array(sp_indices), + values=np.ones((num_non_zeros,)), + dense_shape=np.array([ncols,]))) + + return list_of_sparse_tensors + + +class NpairsLossTest(test.TestCase): + + def testNpairs(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + num_classes = 5 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + + # Compute the loss in NP + reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1)) + reg_term += np.mean(np.sum(np.square(embeddings_positive), 1)) + reg_term *= 0.25 * reg_lambda + + similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T) + + labels_remapped = np.equal( + labels_reshaped, labels_reshaped.T).astype(np.float32) + labels_remapped /= np.sum(labels_remapped, axis=1, keepdims=True) + + xent_loss = math_ops.reduce_mean(nn.softmax_cross_entropy_with_logits( + logits=ops.convert_to_tensor(similarity_matrix), + labels=ops.convert_to_tensor(labels_remapped))).eval() + loss_np = xent_loss + reg_term + + # Compute the loss in TF + loss_tf = metric_loss_ops.npairs_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class NpairsLossMultiLabelTest(test.TestCase): + + def testNpairsMultiLabelLossWithSingleLabelEqualsNpairsLoss(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + labels = np.arange(num_data) + labels = np.reshape(labels, -1) + + # Compute vanila npairs loss. + loss_npairs = metric_loss_ops.npairs_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda).eval() + + # Compute npairs multilabel loss. + labels_one_hot = np.identity(num_data) + loss_npairs_multilabel = metric_loss_ops.npairs_loss_multilabel( + sparse_labels=convert_to_list_of_sparse_tensor(labels_one_hot), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda).eval() + + self.assertAllClose(loss_npairs, loss_npairs_multilabel) + + def testNpairsMultiLabel(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + num_classes = 10 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + + labels = np.random.randint(0, 2, (num_data, num_classes)) + # set entire column to one so that each row has at least one bit set. + labels[:, -1] = 1 + + # Compute the loss in NP + reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1)) + reg_term += np.mean(np.sum(np.square(embeddings_positive), 1)) + reg_term *= 0.25 * reg_lambda + + similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T) + + labels_remapped = np.dot(labels, labels.T).astype(np.float) + labels_remapped /= np.sum(labels_remapped, 1, keepdims=True) + + xent_loss = math_ops.reduce_mean(nn.softmax_cross_entropy_with_logits( + logits=ops.convert_to_tensor(similarity_matrix), + labels=ops.convert_to_tensor(labels_remapped))).eval() + loss_np = xent_loss + reg_term + + # Compute the loss in TF + loss_tf = metric_loss_ops.npairs_loss_multilabel( + sparse_labels=convert_to_list_of_sparse_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda) + loss_tf = loss_tf.eval() + + self.assertAllClose(loss_np, loss_tf) + + +def compute_ground_truth_cluster_score(feat, y): + y_unique = np.unique(y) + score_gt_np = 0.0 + for c in y_unique: + feat_subset = feat[y == c, :] + pdist_subset = pairwise_distance_np(feat_subset) + score_gt_np += -1.0 * np.min(np.sum(pdist_subset, axis=0)) + score_gt_np = score_gt_np.astype(np.float32) + return score_gt_np + + +def compute_cluster_loss_numpy(feat, + y, + margin_multiplier=1.0, + enable_pam_finetuning=True): + if enable_pam_finetuning: + facility = ForwardGreedyFacility( + n_clusters=np.unique(y).size).pam_augmented_fit(feat, y, + margin_multiplier) + else: + facility = ForwardGreedyFacility( + n_clusters=np.unique(y).size).loss_augmented_fit(feat, y, + margin_multiplier) + + score_augmented = facility.score_aug_ + score_gt = compute_ground_truth_cluster_score(feat, y) + return np.maximum(np.float32(0.0), score_augmented - score_gt) + + +class ForwardGreedyFacility(object): + + def __init__(self, n_clusters=8): + self.n_clusters = n_clusters + self.center_ics_ = None + + def _check_init_args(self): + # Check n_clusters. + if (self.n_clusters is None or self.n_clusters <= 0 or + not isinstance(self.n_clusters, int)): + raise ValueError('n_clusters has to be nonnegative integer.') + + def loss_augmented_fit(self, feat, y, loss_mult): + """Fit K-Medoids to the provided data.""" + self._check_init_args() + # Check that the array is good and attempt to convert it to + # Numpy array if possible. + feat = self._check_array(feat) + # Apply distance metric to get the distance matrix. + pdists = pairwise_distance_np(feat) + + num_data = feat.shape[0] + candidate_ids = list(range(num_data)) + candidate_scores = np.zeros(num_data,) + subset = [] + + k = 0 + while k < self.n_clusters: + candidate_scores = [] + for i in candidate_ids: + # push i to subset. + subset.append(i) + marginal_cost = -1.0 * np.sum(np.min(pdists[:, subset], axis=1)) + loss = 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset)) + candidate_scores.append(marginal_cost + loss_mult * loss) + # remove i from subset. + subset.pop() + + # push i_star to subset. + i_star = candidate_ids[np.argmax(candidate_scores)] + subset.append(i_star) + # remove i_star from candidate indices. + candidate_ids.remove(i_star) + k += 1 + + # Expose labels_ which are the assignments of + # the training data to clusters. + self.labels_ = self._get_cluster_ics(pdists, subset) + # Expose cluster centers, i.e. medoids. + self.cluster_centers_ = feat.take(subset, axis=0) + # Expose indices of chosen cluster centers. + self.center_ics_ = subset + # Expose the score = -\sum_{i \in V} min_{j \in S} || x_i - x_j || + self.score_ = np.float32(-1.0) * self._get_facility_distance(pdists, subset) + self.score_aug_ = self.score_ + loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset))) + self.score_aug_ = self.score_aug_.astype(np.float32) + # Expose the chosen cluster indices. + self.subset_ = subset + return self + + def _augmented_update_medoid_ics_in_place(self, pdists, y_gt, cluster_ics, + medoid_ics, loss_mult): + for cluster_idx in range(self.n_clusters): + # y_pred = self._get_cluster_ics(D, medoid_ics) + # Don't prematurely do the assignment step. + # Do this after we've updated all cluster medoids. + y_pred = cluster_ics + + if sum(y_pred == cluster_idx) == 0: + # Cluster is empty. + continue + + curr_score = ( + -1.0 * np.sum( + pdists[medoid_ics[cluster_idx], y_pred == cluster_idx]) + + loss_mult * (1.0 - metrics.normalized_mutual_info_score( + y_gt, y_pred))) + + pdist_in = pdists[y_pred == cluster_idx, :] + pdist_in = pdist_in[:, y_pred == cluster_idx] + + all_scores_fac = np.sum(-1.0 * pdist_in, axis=1) + all_scores_loss = [] + for i in range(y_pred.size): + if y_pred[i] != cluster_idx: + continue + # remove this cluster's current centroid + medoid_ics_i = medoid_ics[:cluster_idx] + medoid_ics[cluster_idx + 1:] + # add this new candidate to the centroid list + medoid_ics_i += [i] + y_pred_i = self._get_cluster_ics(pdists, medoid_ics_i) + all_scores_loss.append(loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score(y_gt, y_pred_i))) + + all_scores = all_scores_fac + all_scores_loss + max_score_idx = np.argmax(all_scores) + max_score = all_scores[max_score_idx] + + if max_score > curr_score: + medoid_ics[cluster_idx] = np.where( + y_pred == cluster_idx)[0][max_score_idx] + + def pam_augmented_fit(self, feat, y, loss_mult): + pam_max_iter = 5 + self._check_init_args() + feat = self._check_array(feat) + pdists = pairwise_distance_np(feat) + self.loss_augmented_fit(feat, y, loss_mult) + print('PAM -1 (before PAM): score: %f, score_aug: %f' % ( + self.score_, self.score_aug_)) + # Initialize from loss augmented facility location + subset = self.center_ics_ + for iter_ in range(pam_max_iter): + # update the cluster assignment + cluster_ics = self._get_cluster_ics(pdists, subset) + # update the medoid for each clusters + self._augmented_update_medoid_ics_in_place(pdists, y, cluster_ics, subset, + loss_mult) + self.score_ = np.float32(-1.0) * self._get_facility_distance( + pdists, subset) + self.score_aug_ = self.score_ + loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset))) + self.score_aug_ = self.score_aug_.astype(np.float32) + print('PAM iter: %d, score: %f, score_aug: %f' % (iter_, self.score_, + self.score_aug_)) + + self.center_ics_ = subset + self.labels_ = cluster_ics + return self + + def _check_array(self, feat): + # Check that the number of clusters is less than or equal to + # the number of samples + if self.n_clusters > feat.shape[0]: + raise ValueError('The number of medoids ' + '({}) '.format( + self.n_clusters) + 'must be larger than the number ' + + 'of samples ({})'.format(feat.shape[0])) + return feat + + def _get_cluster_ics(self, pdists, subset): + """Returns cluster indices for pdist and current medoid indices.""" + # Assign data points to clusters based on + # which cluster assignment yields + # the smallest distance` + cluster_ics = np.argmin(pdists[subset, :], axis=0) + return cluster_ics + + def _get_facility_distance(self, pdists, subset): + return np.sum(np.min(pdists[subset, :], axis=0)) + + +class ClusterLossTest(test.TestCase): + + def _genClusters(self, n_samples, n_clusters): + blobs = datasets.make_blobs( + n_samples=n_samples, centers=n_clusters) + embedding, labels = blobs + embedding = (embedding - embedding.mean(axis=0)) / embedding.std(axis=0) + embedding = embedding.astype(np.float32) + return embedding, labels + + def testClusteringLossPAMOff(self): + if not HAS_SKLEARN: + return + with self.test_session(): + margin_multiplier = 10.0 + embeddings, labels = self._genClusters(n_samples=128, n_clusters=64) + + loss_np = compute_cluster_loss_numpy( + embeddings, labels, margin_multiplier, enable_pam_finetuning=False) + loss_tf = metric_loss_ops.cluster_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embeddings), + margin_multiplier=margin_multiplier, + enable_pam_finetuning=False) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + def testClusteringLossPAMOn(self): + if not HAS_SKLEARN: + return + with self.test_session(): + margin_multiplier = 10.0 + embeddings, labels = self._genClusters(n_samples=128, n_clusters=64) + + loss_np = compute_cluster_loss_numpy( + embeddings, labels, margin_multiplier, enable_pam_finetuning=True) + loss_tf = metric_loss_ops.cluster_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embeddings), + margin_multiplier=margin_multiplier, + enable_pam_finetuning=True) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/makefile/BUILD b/tensorflow/contrib/makefile/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a8dd59f32a7f3b27993a7ee48ee7cc07ada59a4c --- /dev/null +++ b/tensorflow/contrib/makefile/BUILD @@ -0,0 +1,31 @@ +# Necessary build rules for makefile build in our CI. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = ["**/OWNERS"], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +sh_test( + name = "build_all_linux", + size = "enormous", + srcs = ["build_all_linux.sh"], + data = [ + "//tensorflow:all_opensource_files", + "//third_party/eigen3:all_files", + "//third_party/fft2d:all_files", + ], + tags = [ + "manual", + "no_gpu", + "no_oss", + "notap", + ], +) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index e0cfab0b26d8f106e83f6223d057c9ef5f395f4f..3b4d0ff799c05ce34cc55385ccc637467e443e40 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -14,7 +14,10 @@ # Host compilation settings # Find where we're running from, so we can store generated files here. -MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +ifeq ($(origin MAKEFILE_DIR), undefined) + MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +endif + HAS_GEN_HOST_PROTOC := \ $(shell test -f $(MAKEFILE_DIR)/gen/protobuf-host/bin/protoc && echo "true" ||\ echo "false") @@ -41,6 +44,11 @@ ifdef HEXAGON_LIBS endif endif # HEXAGON_LIBS +# If ANDROID_TYPES is not set assume __ANDROID_TYPES_SLIM__ +ifeq ($(ANDROID_TYPES),) + ANDROID_TYPES := -D__ANDROID_TYPES_SLIM__ +endif + # Try to figure out the host system HOST_OS := ifeq ($(OS),Windows_NT) @@ -71,6 +79,7 @@ HOST_LDOPTS += -L/usr/local/lib HOST_INCLUDES := \ -I. \ +-I$(MAKEFILE_DIR)/../../../ \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ @@ -190,6 +199,10 @@ LIBFLAGS := # If we're on OS X, make sure that globals aren't stripped out. ifeq ($(TARGET),OSX) +ifeq ($(HAS_GEN_HOST_PROTOC),true) + LIBFLAGS += -L$(MAKEFILE_DIR)/gen/protobuf-host/lib + export LD_LIBRARY_PATH=$(MAKEFILE_DIR)/gen/protobuf-host/lib +endif LDFLAGS += -all_load endif # Make sure that we don't strip global constructors on Linux. @@ -208,7 +221,7 @@ ifeq ($(TARGET),LINUX) endif # If we're cross-compiling for the Raspberry Pi, use the right gcc. ifeq ($(TARGET),PI) - CXXFLAGS += -D__ANDROID_TYPES_SLIM__ -DRASPBERRY_PI + CXXFLAGS += $(ANDROID_TYPES) -DRASPBERRY_PI LDFLAGS := -Wl,--no-whole-archive LIBS += -ldl -lpthread LIBFLAGS += -Wl,--allow-multiple-definition -Wl,--whole-archive @@ -330,7 +343,7 @@ ifeq ($(TARGET),IOS) -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} @@ -354,7 +367,7 @@ ifeq ($(TARGET),IOS) -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} @@ -377,7 +390,7 @@ ifeq ($(TARGET),IOS) -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} @@ -401,7 +414,7 @@ ifeq ($(TARGET),IOS) -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONESIMULATOR_SYSROOT} @@ -424,7 +437,7 @@ ifeq ($(TARGET),IOS) -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONESIMULATOR_SYSROOT} @@ -484,6 +497,7 @@ $(wildcard tensorflow/core/*/*/*main.cc) \ $(wildcard tensorflow/core/debug/*.cc) \ $(wildcard tensorflow/core/framework/op_gen_lib.cc) \ $(wildcard tensorflow/core/graph/dot.*) \ +$(wildcard tensorflow/core/lib/db/*) \ $(wildcard tensorflow/core/lib/gif/*) \ $(wildcard tensorflow/core/lib/io/zlib*) \ $(wildcard tensorflow/core/lib/io/record*) \ @@ -501,6 +515,7 @@ $(wildcard tensorflow/core/platform/google/*) \ $(wildcard tensorflow/core/platform/google/*/*) \ $(wildcard tensorflow/core/platform/jpeg.*) \ $(wildcard tensorflow/core/platform/png.*) \ +$(wildcard tensorflow/core/platform/s3/*) \ $(wildcard tensorflow/core/platform/stream_executor.*) \ $(wildcard tensorflow/core/platform/windows/*) \ $(wildcard tensorflow/core/user_ops/*.cu.cc) \ diff --git a/tensorflow/contrib/makefile/build_all_linux.sh b/tensorflow/contrib/makefile/build_all_linux.sh index 5d73f697f4ef0b2a566deb04397b0def5a442cfa..a440633cfc23a7c606586a3b53180aaed6fe27ad 100755 --- a/tensorflow/contrib/makefile/build_all_linux.sh +++ b/tensorflow/contrib/makefile/build_all_linux.sh @@ -44,4 +44,5 @@ tensorflow/contrib/makefile/compile_linux_protobuf.sh # Build TensorFlow. make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ OPTFLAGS="-O3 -march=native" \ - HOST_CXXFLAGS="--std=c++11 -march=native" + HOST_CXXFLAGS="--std=c++11 -march=native" \ + MAKEFILE_DIR=$SCRIPT_DIR diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 39c89628d96ad1d7d8a28ec76071d4aa31085225..12e3f589306d54b10b38a48d8aed356de4ddc91b 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -20,11 +20,11 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads BZL_FILE_PATH=tensorflow/workspace.bzl EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" -GEMMLOWP_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" -NSYNC_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -PROTOBUF_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -RE2_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, @@ -54,7 +54,7 @@ download_and_extract() { elif [[ "${url}" == *zip ]]; then tempdir=$(mktemp -d) tempdir2=$(mktemp -d) - wget ${url} -P ${tempdir} + wget -P ${tempdir} ${url} unzip ${tempdir}/* -d ${tempdir2} # unzip has no strip components, so unzip to a temp dir, and move the files # we want from the tempdir to destination. diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 5ade8942af39f1d308c5f6e308e1cee754510926..938c4a53ab3fff72b028276eac5aad76ff01880d 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -24,6 +24,7 @@ tensorflow/core/framework/summary.pb.cc tensorflow/core/framework/step_stats.pb.cc tensorflow/core/framework/resource_handle.pb.cc tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc +tensorflow/core/framework/api_def.pb.cc tensorflow/core/framework/op_def.pb.cc tensorflow/core/framework/node_def.pb.cc tensorflow/core/framework/log_memory.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 1f0ad06cdc5b98ae9c08ea63dad70eb02b6ef46b..aa91b2f954504c42d33838c728abd666ef100e14 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/summary.pb.h tensorflow/core/framework/step_stats.pb.h tensorflow/core/framework/resource_handle.pb.h tensorflow/core/framework/remote_fused_graph_execute_info.pb.h +tensorflow/core/framework/api_def.pb.h tensorflow/core/framework/op_def.pb.h tensorflow/core/framework/node_def.pb.h tensorflow/core/framework/log_memory.pb.h diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 1fda907074545d9b78a902182e4cec9e4212c22d..8b77c99cb574123c2af5d8f9f17cd403613cfffd 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -170,6 +170,8 @@ tensorflow/core/kernels/cwise_op_div.cc tensorflow/core/kernels/cwise_op_bitwise_xor.cc tensorflow/core/kernels/cwise_op_bitwise_or.cc tensorflow/core/kernels/cwise_op_bitwise_and.cc +tensorflow/core/kernels/cwise_op_left_shift.cc +tensorflow/core/kernels/cwise_op_right_shift.cc tensorflow/core/kernels/cwise_op_add_2.cc tensorflow/core/kernels/cwise_op_add_1.cc tensorflow/core/kernels/cwise_op_abs.cc @@ -262,3 +264,4 @@ tensorflow/core/kernels/spacetobatch_functor.cc tensorflow/core/kernels/spacetobatch_op.cc tensorflow/core/kernels/batchtospace_op.cc tensorflow/core/kernels/warn_about_ints.cc +tensorflow/core/kernels/segment_reduction_ops.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index c39257ffa91fef184e8bd5258b19c4323a1b7fe0..b5431df2eb016d010c51bdbb33fd747b3569ce83 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -17,6 +17,7 @@ tensorflow/core/framework/summary.pb_text.cc tensorflow/core/framework/step_stats.pb_text.cc tensorflow/core/framework/resource_handle.pb_text.cc tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc +tensorflow/core/framework/api_def.pb_text.cc tensorflow/core/framework/op_def.pb_text.cc tensorflow/core/framework/node_def.pb_text.cc tensorflow/core/framework/log_memory.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index a1a9aa7190205d9f3c34ef01b65db85f89f2ac85..d569bde637b20e0ca55c48c616855332abd9fb13 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -30,6 +30,7 @@ tensorflow/core/framework/step_stats.proto tensorflow/core/framework/resource_handle.proto tensorflow/core/framework/remote_fused_graph_execute_info.proto tensorflow/core/framework/reader_base.proto +tensorflow/core/framework/api_def.proto tensorflow/core/framework/op_def.proto tensorflow/core/framework/node_def.proto tensorflow/core/framework/log_memory.proto diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index dd479147749e286caa34aa658a8a51f865bcee20..7e2e96e160167ae68d3bdabacbbbeb45df61778f 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -54,15 +54,13 @@ class BytesInUseOp : public MemoryStatsOp { }; // Register this op on GPU only, see comment for MaxBytesInUse for reason -REGISTER_KERNEL_BUILDER( - Name("BytesInUse").Device(DEVICE_GPU).HostMemory("out"), - BytesInUseOp); +REGISTER_KERNEL_BUILDER(Name("BytesInUse").Device(DEVICE_GPU).HostMemory("out"), + BytesInUseOp); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER( - Name("BytesInUse").Device(DEVICE_SYCL).HostMemory("out"), - MaxBytesInUseOp); -#endif // TENSORFLOW_USE_SYCL + Name("BytesInUse").Device(DEVICE_SYCL).HostMemory("out"), MaxBytesInUseOp); +#endif // TENSORFLOW_USE_SYCL // Op that measures the total memory (in bytes) of a device. class BytesLimitOp : public MemoryStatsOp { diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py index 303c02dfa409bb7410233f0005d9f3cb0b5bc11e..2932ae1c8df32cd936cff932b061571c513fda79 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py @@ -749,7 +749,7 @@ def meta_graph_transform( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) - # Append newly added initalizers to collection. + # Append newly added initializers to collection. _add_new_inits_to_collection(meta_graph_def, updated_initializer_names) # Copy signature_defs, excluding any pruned nodes diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index a9bce65e55f36da0b930fcff619e9ed8841ef4c2..bb566f69029b4cd3b530c31bda22d78a19d9bf02 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -22,6 +22,10 @@ See the @{$python/contrib.metrics} guide. @@streaming_recall_at_thresholds @@streaming_precision @@streaming_precision_at_thresholds +@@streaming_false_positive_rate +@@streaming_false_positive_rate_at_thresholds +@@streaming_false_negative_rate +@@streaming_false_negative_rate_at_thresholds @@streaming_auc @@streaming_curve_points @@streaming_recall_at_k @@ -61,6 +65,7 @@ See the @{$python/contrib.metrics} guide. @@set_intersection @@set_size @@set_union +@@count """ from __future__ import absolute_import @@ -74,14 +79,19 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics +from tensorflow.contrib.metrics.python.ops.metric_ops import count from tensorflow.contrib.metrics.python.ops.metric_ops import sparse_recall_at_top_k from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 76986d0156dada75abcc559d9db6b9addf26cccc..177c4c53f7ce321ac542e6767c499b314e96adb7 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -22,11 +22,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections as collections_lib + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import confusion_matrix from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -56,7 +57,10 @@ def _safe_div(numerator, denominator, name): name=name) -def _create_local(name, shape, collections=None, validate_shape=True, +def _create_local(name, + shape, + collections=None, + validate_shape=True, dtype=dtypes.float32): """Creates a new local variable. @@ -87,7 +91,9 @@ def _assert_weights_rank(weights, values): return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) -def _count_condition(values, weights=None, metrics_collections=None, +def _count_condition(values, + weights=None, + metrics_collections=None, updates_collections=None): """Sums the weights of cases where the given values are True. @@ -114,7 +120,7 @@ def _count_condition(values, weights=None, metrics_collections=None, or tuple. """ check_ops.assert_type(values, dtypes.bool) - count = _create_local('count', shape=[]) + count_ = _create_local('count', shape=[]) values = math_ops.to_float(values) if weights is not None: @@ -122,8 +128,8 @@ def _count_condition(values, weights=None, metrics_collections=None, with ops.control_dependencies((_assert_weights_rank(weights, values),)): values = math_ops.multiply(values, weights) - value_tensor = array_ops.identity(count) - update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) + value_tensor = array_ops.identity(count_) + update_op = state_ops.assign_add(count_, math_ops.reduce_sum(values)) if metrics_collections: ops.add_to_collections(metrics_collections, value_tensor) @@ -134,7 +140,9 @@ def _count_condition(values, weights=None, metrics_collections=None, return value_tensor, update_op -def streaming_true_positives(predictions, labels, weights=None, +def streaming_true_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -168,12 +176,17 @@ def streaming_true_positives(predictions, labels, weights=None, tuple. """ return metrics.true_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_true_negatives(predictions, labels, weights=None, +def streaming_true_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -206,20 +219,22 @@ def streaming_true_negatives(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'true_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, False)) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, updates_collections) -def streaming_false_positives(predictions, labels, weights=None, +def streaming_false_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -253,12 +268,17 @@ def streaming_false_positives(predictions, labels, weights=None, tuple. """ return metrics.false_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_false_negatives(predictions, labels, weights=None, +def streaming_false_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -291,9 +311,12 @@ def streaming_false_negatives(predictions, labels, weights=None, or tuple. """ return metrics.false_negatives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. @@ -317,17 +340,18 @@ def _broadcast_weights(weights, values): with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: weights_shape = weights.get_shape() values_shape = values.get_shape() - if (weights_shape.is_fully_defined() and - values_shape.is_fully_defined() and + if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and weights_shape.is_compatible_with(values_shape)): return weights with ops.control_dependencies((_assert_weights_rank(weights, values),)): - return math_ops.multiply( - weights, array_ops.ones_like(values), name=scope) + return math_ops.multiply(weights, array_ops.ones_like(values), name=scope) -def streaming_mean(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the (weighted) mean of the given values. The `streaming_mean` function creates two local variables, `total` and `count` @@ -365,12 +389,18 @@ def streaming_mean(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_tensor(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean_tensor(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the element-wise (weighted) mean of the given tensors. In contrast to the `streaming_mean` function which returns a scalar with the @@ -412,12 +442,18 @@ def streaming_mean_tensor(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean_tensor( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_accuracy(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_accuracy(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Calculates how often `predictions` matches `labels`. @@ -462,13 +498,19 @@ def streaming_accuracy(predictions, labels, weights=None, tuple. """ return metrics.accuracy( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_precision(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_precision(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the precision of the predictions with respect to the labels. @@ -512,13 +554,19 @@ def streaming_precision(predictions, labels, weights=None, tuple. """ return metrics.precision( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_recall(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_recall(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall of the predictions with respect to the labels. @@ -560,13 +608,242 @@ def streaming_recall(predictions, labels, weights=None, tuple. """ return metrics.recall( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) + +def _true_negatives(labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Sum the weights of true negatives. -def _streaming_confusion_matrix_at_thresholds( - predictions, labels, thresholds, weights=None, includes=None): + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric + value variable should be added to. + updates_collections: An optional list of collections that the metric update + ops should be added to. + name: An optional variable_scope name. + + Returns: + value_tensor: A `Tensor` representing the current value of the metric. + update_op: An operation that accumulates the error from a batch of data. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): + + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) + return _count_condition(is_true_negative, weights, metrics_collections, + updates_collections) + + +def streaming_false_positive_rate(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the false positive rate of predictions with respect to labels. + + The `false_positive_rate` function creates two local variables, + `false_positives` and `true_negatives`, that are used to compute the + false positive rate. This value is ultimately returned as + `false_positive_rate`, an idempotent operation that simply divides + `false_positives` by the sum of `false_positives` and `true_negatives`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. `update_op` weights each prediction by the + corresponding value in `weights`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_positive_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_positive_rate: Scalar float `Tensor` with the value of + `false_positives` divided by the sum of `false_positives` and + `true_negatives`. + update_op: `Operation` that increments `false_positives` and + `true_negatives` variables appropriately and whose value matches + `false_positive_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_positive_rate', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + + false_p, false_positives_update_op = metrics.false_positives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + true_n, true_negatives_update_op = _true_negatives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + + def compute_fpr(fp, tn, name): + return array_ops.where( + math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name) + + fpr = compute_fpr(false_p, true_n, 'value') + update_op = compute_fpr(false_positives_update_op, true_negatives_update_op, + 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fpr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fpr, update_op + + +def streaming_false_negative_rate(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the false negative rate of predictions with respect to labels. + + The `false_negative_rate` function creates two local variables, + `false_negatives` and `true_positives`, that are used to compute the + false positive rate. This value is ultimately returned as + `false_negative_rate`, an idempotent operation that simply divides + `false_negatives` by the sum of `false_negatives` and `true_positives`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_negative_rate`. `update_op` weights each prediction by the + corresponding value in `weights`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_negative_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_negative_rate: Scalar float `Tensor` with the value of + `false_negatives` divided by the sum of `false_negatives` and + `true_positives`. + update_op: `Operation` that increments `false_negatives` and + `true_positives` variables appropriately and whose value matches + `false_negative_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_negative_rate', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + + false_n, false_negatives_update_op = metrics.false_negatives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + true_p, true_positives_update_op = metrics.true_positives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + + def compute_fnr(fn, tp, name): + return array_ops.where( + math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name) + + fnr = compute_fnr(false_n, true_p, 'value') + update_op = compute_fnr(false_negatives_update_op, true_positives_update_op, + 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fnr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fnr, update_op + + +def _streaming_confusion_matrix_at_thresholds(predictions, + labels, + thresholds, + weights=None, + includes=None): """Computes true_positives, false_negatives, true_negatives, false_positives. This function creates up to four local variables, `true_positives`, @@ -618,7 +895,7 @@ def _streaming_confusion_matrix_at_thresholds( if include not in all_includes: raise ValueError('Invaild key: %s.' % include) - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -654,8 +931,8 @@ def _streaming_confusion_matrix_at_thresholds( if weights is not None: broadcast_weights = weights_broadcast_ops.broadcast_weights( math_ops.to_float(weights), predictions) - weights_tiled = array_ops.tile(array_ops.reshape( - broadcast_weights, [1, -1]), [num_thresholds, 1]) + weights_tiled = array_ops.tile( + array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1]) thresh_tiled.get_shape().assert_is_compatible_with( weights_tiled.get_shape()) else: @@ -670,8 +947,9 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_pos, pred_is_pos)) if weights_tiled is not None: is_true_positive *= weights_tiled - update_ops['tp'] = state_ops.assign_add( - true_positives, math_ops.reduce_sum(is_true_positive, 1)) + update_ops['tp'] = state_ops.assign_add(true_positives, + math_ops.reduce_sum( + is_true_positive, 1)) values['tp'] = true_positives if 'fn' in includes: @@ -680,8 +958,9 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_pos, pred_is_neg)) if weights_tiled is not None: is_false_negative *= weights_tiled - update_ops['fn'] = state_ops.assign_add( - false_negatives, math_ops.reduce_sum(is_false_negative, 1)) + update_ops['fn'] = state_ops.assign_add(false_negatives, + math_ops.reduce_sum( + is_false_negative, 1)) values['fn'] = false_negatives if 'tn' in includes: @@ -690,8 +969,9 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_neg, pred_is_neg)) if weights_tiled is not None: is_true_negative *= weights_tiled - update_ops['tn'] = state_ops.assign_add( - true_negatives, math_ops.reduce_sum(is_true_negative, 1)) + update_ops['tn'] = state_ops.assign_add(true_negatives, + math_ops.reduce_sum( + is_true_negative, 1)) values['tn'] = true_negatives if 'fp' in includes: @@ -700,36 +980,45 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_neg, pred_is_pos)) if weights_tiled is not None: is_false_positive *= weights_tiled - update_ops['fp'] = state_ops.assign_add( - false_positives, math_ops.reduce_sum(is_false_positive, 1)) + update_ops['fp'] = state_ops.assign_add(false_positives, + math_ops.reduce_sum( + is_false_positive, 1)) values['fp'] = false_positives return values, update_ops -def streaming_true_positives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_true_positives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('tp',)) return values['tp'], update_ops['tp'] -def streaming_false_negatives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_false_negatives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('fn',)) return values['fn'], update_ops['fn'] -def streaming_false_positives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_false_positives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('fp',)) return values['fp'], update_ops['fp'] -def streaming_true_negatives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_true_negatives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('tn',)) return values['tn'], update_ops['tn'] @@ -788,9 +1077,12 @@ def streaming_curve_points(labels=None, `weights` is not `None` and its shape doesn't match `predictions`, or if either `metrics_collections` or `updates_collections` are not a list or tuple. + + TODO(chizeng): Consider rewriting this method to make use of logic within the + streaming_precision_recall_at_equal_thresholds method (to improve run time). """ - with variable_scope.variable_scope(name, 'curve_points', (labels, predictions, - weights)): + with variable_scope.variable_scope(name, 'curve_points', + (labels, predictions, weights)): if curve != 'ROC' and curve != 'PR': raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) kepsilon = 1e-7 # to account for floating point imprecisions @@ -831,9 +1123,14 @@ def streaming_curve_points(labels=None, return points, update_op -def streaming_auc(predictions, labels, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, - curve='ROC', name=None): +def streaming_auc(predictions, + labels, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + curve='ROC', + name=None): """Computes the approximate AUC via a Riemann sum. The `streaming_auc` function creates four local variables, `true_positives`, @@ -890,14 +1187,201 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200, tuple. """ return metrics.auc( - predictions=predictions, labels=labels, weights=weights, - metrics_collections=metrics_collections, num_thresholds=num_thresholds, - curve=curve, updates_collections=updates_collections, name=name) + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + num_thresholds=num_thresholds, + curve=curve, + updates_collections=updates_collections, + name=name) + + +def streaming_precision_recall_at_equal_thresholds(predictions, + labels, + num_thresholds=None, + weights=None, + name=None, + use_locking=None): + """A helper method for creating metrics related to precision-recall curves. + + These values are true positives, false negatives, true negatives, false + positives, precision, and recall. This function returns a data structure that + contains ops within it. + + Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N) + space and run time), this op exhibits O(T + N) space and run time, where T is + the number of thresholds and N is the size of the predictions tensor. Hence, + it may be advantageous to use this function when `predictions` is big. + + For instance, prefer this method for per-pixel classification tasks, for which + the predictions tensor may be very large. + + Each number in `predictions`, a float in `[0, 1]`, is compared with its + corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at + each threshold. This is then multiplied with `weights` which can be used to + reweight certain values, or more commonly used for masking values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A bool `Tensor` whose shape matches `predictions`. + num_thresholds: Optional; Number of thresholds, evenly distributed in + `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins + is 1 less than `num_thresholds`. Using an even `num_thresholds` value + instead of an odd one may yield unfriendly edges for bins. + weights: Optional; If provided, a `Tensor` that has the same dtype as, + and broadcastable to, `predictions`. This tensor is multplied by counts. + name: Optional; variable_scope name. If not provided, the string + 'precision_recall_at_equal_threshold' is used. + use_locking: Optional; If True, the op will be protected by a lock. + Otherwise, the behavior is undefined, but may exhibit less contention. + Defaults to True. + + Returns: + result: A named tuple (See PrecisionRecallData within the implementation of + this function) with properties that are variables of shape + `[num_thresholds]`. The names of the properties are tp, fp, tn, fn, + precision, recall, thresholds. + update_op: An op that accumulates values. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + `includes` contains invalid keys. + """ + # Disable the invalid-name checker so that we can capitalize the name. + # pylint: disable=invalid-name + PrecisionRecallData = collections_lib.namedtuple( + 'PrecisionRecallData', + ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds']) + # pylint: enable=invalid-name + + if num_thresholds is None: + num_thresholds = 201 + + if weights is None: + weights = 1.0 + + if use_locking is None: + use_locking = True + + check_ops.assert_type(labels, dtypes.bool) + + dtype = predictions.dtype + with variable_scope.variable_scope(name, + 'precision_recall_at_equal_thresholds', + (labels, predictions, weights)): + # Make sure that predictions are within [0.0, 1.0]. + with ops.control_dependencies([ + check_ops.assert_greater_equal( + predictions, + math_ops.cast(0.0, dtype=predictions.dtype), + message='predictions must be in [0, 1]'), + check_ops.assert_less_equal( + predictions, + math_ops.cast(1.0, dtype=predictions.dtype), + message='predictions must be in [0, 1]') + ]): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, + labels=labels, + weights=weights) + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) -def streaming_specificity_at_sensitivity( - predictions, labels, sensitivity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): + # We cast to float to ensure we have 0.0 or 1.0. + f_labels = math_ops.cast(labels, dtype) + + # Get weighted true/false labels. + true_labels = f_labels * weights + false_labels = (1.0 - f_labels) * weights + + # Flatten predictions and labels. + predictions = array_ops.reshape(predictions, [-1]) + true_labels = array_ops.reshape(true_labels, [-1]) + false_labels = array_ops.reshape(false_labels, [-1]) + + # To compute TP/FP/TN/FN, we are measuring a binary classifier + # C(t) = (predictions >= t) + # at each threshold 't'. So we have + # TP(t) = sum( C(t) * true_labels ) + # FP(t) = sum( C(t) * false_labels ) + # + # But, computing C(t) requires computation for each t. To make it fast, + # observe that C(t) is a cumulative integral, and so if we have + # thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} + # where n = num_thresholds, and if we can compute the bucket function + # B(i) = Sum( (predictions == t), t_i <= t < t{i+1} ) + # then we get + # C(t_i) = sum( B(j), j >= i ) + # which is the reversed cumulative sum in tf.cumsum(). + # + # We can compute B(i) efficiently by taking advantage of the fact that + # our thresholds are evenly distributed, in that + # width = 1.0 / (num_thresholds - 1) + # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] + # Given a prediction value p, we can map it to its bucket by + # bucket_index(p) = floor( p * (num_thresholds - 1) ) + # so we can use tf.scatter_add() to update the buckets in one pass. + # + # This implementation exhibits a run time and space complexity of O(T + N), + # where T is the number of thresholds and N is the size of predictions. + # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead + # exhibit a complexity of O(T * N). + + # Compute the bucket indices for each prediction value. + bucket_indices = math_ops.cast( + math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32) + + with ops.name_scope('variables'): + tp_buckets_v = _create_local( + 'tp_buckets', shape=[num_thresholds], dtype=dtype) + fp_buckets_v = _create_local( + 'fp_buckets', shape=[num_thresholds], dtype=dtype) + + with ops.name_scope('update_op'): + update_tp = state_ops.scatter_add( + tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking) + update_fp = state_ops.scatter_add( + fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking) + + # Set up the cumulative sums to compute the actual metrics. + tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp') + fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp') + # fn = sum(true_labels) - tp + # = sum(tp_buckets) - tp + # = tp[0] - tp + # Similarly, + # tn = fp[0] - fp + tn = fp[0] - fp + fn = tp[0] - tp + + # We use a minimum to prevent division by 0. + epsilon = 1e-7 + precision = tp / math_ops.maximum(epsilon, tp + fp) + recall = tp / math_ops.maximum(epsilon, tp + fn) + + result = PrecisionRecallData( + tp=tp, + fp=fp, + tn=tn, + fn=fn, + precision=precision, + recall=recall, + thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds)) + update_op = control_flow_ops.group(update_tp, update_fp) + return result, update_op + + +def streaming_specificity_at_sensitivity(predictions, + labels, + sensitivity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the specificity at a given sensitivity. The `streaming_specificity_at_sensitivity` function creates four local @@ -947,15 +1431,24 @@ def streaming_specificity_at_sensitivity( or `updates_collections` are not a list or tuple. """ return metrics.specificity_at_sensitivity( - sensitivity=sensitivity, num_thresholds=num_thresholds, - predictions=predictions, labels=labels, weights=weights, + sensitivity=sensitivity, + num_thresholds=num_thresholds, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_sensitivity_at_specificity( - predictions, labels, specificity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): +def streaming_sensitivity_at_specificity(predictions, + labels, + specificity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the sensitivity at a given specificity. The `streaming_sensitivity_at_specificity` function creates four local @@ -1005,16 +1498,23 @@ def streaming_sensitivity_at_specificity( or `updates_collections` are not a list or tuple. """ return metrics.sensitivity_at_specificity( - specificity=specificity, num_thresholds=num_thresholds, - predictions=predictions, labels=labels, weights=weights, + specificity=specificity, + num_thresholds=num_thresholds, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_precision_at_thresholds(predictions, labels, thresholds, +def streaming_precision_at_thresholds(predictions, + labels, + thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): + updates_collections=None, + name=None): """Computes precision values for different `thresholds` on `predictions`. The `streaming_precision_at_thresholds` function creates four local variables, @@ -1059,14 +1559,21 @@ def streaming_precision_at_thresholds(predictions, labels, thresholds, """ return metrics.precision_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_recall_at_thresholds(predictions, labels, thresholds, - weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_recall_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various recall values for different `thresholds` on `predictions`. The `streaming_recall_at_thresholds` function creates four local variables, @@ -1109,9 +1616,154 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds, """ return metrics.recall_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) + + +def streaming_false_positive_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes various fpr values for different `thresholds` on `predictions`. + + The `streaming_false_positive_rate_at_thresholds` function creates two + local variables, `false_positives`, `true_negatives`, for various values of + thresholds. `false_positive_rate[i]` is defined as the total weight + of values in `predictions` above `thresholds[i]` whose corresponding entry in + `labels` is `False`, divided by the total weight of `False` values in `labels` + (`false_positives[i] / (false_positives[i] + true_negatives[i])`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A `bool` `Tensor` whose shape matches `predictions`. + thresholds: A python list or tuple of float thresholds in `[0, 1]`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and + must be broadcastable to `labels` (i.e., all dimensions must be either + `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_positive_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`. + update_op: An operation that increments the `false_positives` and + `true_negatives` variables that are used in the computation of + `false_positive_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds', + (predictions, labels, weights)): + values, update_ops = _streaming_confusion_matrix_at_thresholds( + predictions, labels, thresholds, weights, includes=('fp', 'tn')) + + # Avoid division by zero. + epsilon = 1e-7 + + def compute_fpr(fp, tn, name): + return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) + + fpr = compute_fpr(values['fp'], values['tn'], 'value') + update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fpr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fpr, update_op + + +def streaming_false_negative_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes various fnr values for different `thresholds` on `predictions`. + + The `streaming_false_negative_rate_at_thresholds` function creates two + local variables, `false_negatives`, `true_positives`, for various values of + thresholds. `false_negative_rate[i]` is defined as the total weight + of values in `predictions` above `thresholds[i]` whose corresponding entry in + `labels` is `False`, divided by the total weight of `True` values in `labels` + (`false_negatives[i] / (false_negatives[i] + true_positives[i])`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A `bool` `Tensor` whose shape matches `predictions`. + thresholds: A python list or tuple of float thresholds in `[0, 1]`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and + must be broadcastable to `labels` (i.e., all dimensions must be either + `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_negative_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`. + update_op: An operation that increments the `false_negatives` and + `true_positives` variables that are used in the computation of + `false_negative_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds', + (predictions, labels, weights)): + values, update_ops = _streaming_confusion_matrix_at_thresholds( + predictions, labels, thresholds, weights, includes=('fn', 'tp')) + + # Avoid division by zero. + epsilon = 1e-7 + + def compute_fnr(fn, tp, name): + return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) + + fnr = compute_fnr(values['fn'], values['tp'], 'value') + update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fnr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fnr, update_op def _at_k_name(name, k=None, class_id=None): @@ -1126,8 +1778,12 @@ def _at_k_name(name, k=None, class_id=None): @deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 'and reshape labels from [batch_size] to [batch_size, 1].') -def streaming_recall_at_k(predictions, labels, k, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_recall_at_k(predictions, + labels, + k, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall@k of the predictions with respect to dense labels. @@ -1173,11 +1829,8 @@ def streaming_recall_at_k(predictions, labels, k, weights=None, tuple. """ in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) - return streaming_mean(in_top_k, - weights, - metrics_collections, - updates_collections, - name or _at_k_name('recall', k)) + return streaming_mean(in_top_k, weights, metrics_collections, + updates_collections, name or _at_k_name('recall', k)) # TODO(ptucker): Validate range of values in labels? @@ -1256,10 +1909,14 @@ def streaming_sparse_recall_at_k(predictions, are not a list or tuple. """ return metrics.recall_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1341,10 +1998,14 @@ def streaming_sparse_precision_at_k(predictions, are not a list or tuple. """ return metrics.sparse_precision_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1423,10 +2084,9 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, ValueError: If `top_k_predictions` has rank < 2. """ default_name = _at_k_name('precision', class_id=class_id) - with ops.name_scope( - name, default_name, - (top_k_predictions, labels, weights)) as name_scope: - return metrics_impl._sparse_precision_at_top_k( # pylint: disable=protected-access + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: + return metrics_impl.precision_at_top_k( labels=labels, predictions_idx=top_k_predictions, class_id=class_id, @@ -1505,8 +2165,8 @@ def sparse_recall_at_top_k(labels, are not a list or tuple. """ default_name = _at_k_name('recall', class_id=class_id) - with ops.name_scope(name, default_name, (top_k_predictions, labels, - weights)) as name_scope: + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access labels=labels, predictions_idx=top_k_predictions, @@ -1576,9 +2236,13 @@ def streaming_sparse_average_precision_at_k(predictions, value matches `metric`. """ return metrics.sparse_average_precision_at_k( - k=k, predictions=predictions, labels=labels, weights=weights, + k=k, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_sparse_average_precision_at_top_k(top_k_predictions, @@ -1644,7 +2308,9 @@ def streaming_sparse_average_precision_at_top_k(top_k_predictions, name=name) -def streaming_mean_absolute_error(predictions, labels, weights=None, +def streaming_mean_absolute_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1692,12 +2358,18 @@ def streaming_mean_absolute_error(predictions, labels, weights=None, tuple. """ return metrics.mean_absolute_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, +def streaming_mean_relative_error(predictions, + labels, + normalizer, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1746,12 +2418,18 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, tuple. """ return metrics.mean_relative_error( - normalizer=normalizer, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + normalizer=normalizer, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_squared_error(predictions, labels, weights=None, +def streaming_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1799,12 +2477,17 @@ def streaming_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_root_mean_squared_error(predictions, labels, weights=None, +def streaming_root_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1852,9 +2535,12 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.root_mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_covariance(predictions, @@ -1910,12 +2596,12 @@ def streaming_covariance(predictions, ValueError: If labels and predictions are of different sizes or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'covariance', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + with variable_scope.variable_scope(name, 'covariance', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - count = _create_local('count', []) + count_ = _create_local('count', []) mean_prediction = _create_local('mean_prediction', []) mean_label = _create_local('mean_label', []) comoment = _create_local('comoment', []) # C_A in update equation @@ -1930,7 +2616,7 @@ def streaming_covariance(predictions, weighted_predictions = math_ops.multiply(predictions, weights) weighted_labels = math_ops.multiply(labels, weights) - update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn + update_count = state_ops.assign_add(count_, batch_count) # n_AB in eqn prev_count = update_count - batch_count # n_A in update equation # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) @@ -1955,34 +2641,34 @@ def streaming_covariance(predictions, # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label - unweighted_batch_coresiduals = ( - (predictions - batch_mean_prediction) * (labels - batch_mean_label)) + unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) * + (labels - batch_mean_label)) # batch_comoment is C_B in the update equation if weights is None: batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) else: - batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * - weights) + batch_comoment = math_ops.reduce_sum( + unweighted_batch_coresiduals * weights) # View delta_comoment as = C_AB - C_A in the update equation above. # Since C_A is stored in a var, by how much do we need to increment that var # to make the var = C_AB? - delta_comoment = (batch_comoment + - (prev_mean_prediction - batch_mean_prediction) * - (prev_mean_label - batch_mean_label) * - (prev_count * batch_count / update_count)) + delta_comoment = ( + batch_comoment + (prev_mean_prediction - batch_mean_prediction) * + (prev_mean_label - batch_mean_label) * + (prev_count * batch_count / update_count)) update_comoment = state_ops.assign_add(comoment, delta_comoment) covariance = array_ops.where( - math_ops.less_equal(count, 1.), + math_ops.less_equal(count_, 1.), float('nan'), - math_ops.truediv(comoment, count - 1), + math_ops.truediv(comoment, count_ - 1), name='covariance') with ops.control_dependencies([update_comoment]): update_op = array_ops.where( - math_ops.less_equal(count, 1.), + math_ops.less_equal(count_, 1.), float('nan'), - math_ops.truediv(comoment, count - 1), + math_ops.truediv(comoment, count_ - 1), name='update_op') if metrics_collections: @@ -2044,9 +2730,9 @@ def streaming_pearson_correlation(predictions, `weights` is the wrong size, or if either `metrics_collections` or `updates_collections` are not a `list` or `tuple`. """ - with variable_scope.variable_scope( - name, 'pearson_r', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + with variable_scope.variable_scope(name, 'pearson_r', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) # Broadcast weights here to avoid duplicate broadcasting in each call to @@ -2062,13 +2748,14 @@ def streaming_pearson_correlation(predictions, pearson_r = math_ops.truediv( cov, - math_ops.multiply(math_ops.sqrt(var_predictions), - math_ops.sqrt(var_labels)), + math_ops.multiply( + math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), name='pearson_r') update_op = math_ops.truediv( update_cov, - math_ops.multiply(math_ops.sqrt(update_var_predictions), - math_ops.sqrt(update_var_labels)), + math_ops.multiply( + math_ops.sqrt(update_var_predictions), + math_ops.sqrt(update_var_labels)), name='update_op') if metrics_collections: @@ -2082,7 +2769,10 @@ def streaming_pearson_correlation(predictions, # TODO(nsilberman): add a 'normalized' flag so that the user can request # normalization if the inputs are not normalized. -def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, +def streaming_mean_cosine_distance(predictions, + labels, + dim, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2124,16 +2814,15 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) - radial_diffs = math_ops.reduce_sum(radial_diffs, - reduction_indices=[dim,], - keep_dims=True) - mean_distance, update_op = streaming_mean(radial_diffs, weights, - None, - None, + radial_diffs = math_ops.reduce_sum( + radial_diffs, reduction_indices=[ + dim, + ], keep_dims=True) + mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, name or 'mean_cosine_distance') mean_distance = math_ops.subtract(1.0, mean_distance) update_op = math_ops.subtract(1.0, update_op) @@ -2147,7 +2836,9 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, return mean_distance, update_op -def streaming_percentage_less(values, threshold, weights=None, +def streaming_percentage_less(values, + threshold, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2187,9 +2878,12 @@ def streaming_percentage_less(values, threshold, weights=None, or tuple. """ return metrics.percentage_below( - values=values, threshold=threshold, weights=weights, + values=values, + threshold=threshold, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_mean_iou(predictions, @@ -2241,9 +2935,13 @@ def streaming_mean_iou(predictions, tuple. """ return metrics.mean_iou( - num_classes=num_classes, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + num_classes=num_classes, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) def _next_array_size(required_size, growth_factor=1.5): @@ -2258,9 +2956,9 @@ def _next_array_size(required_size, growth_factor=1.5): tf.Tensor with dtype=int32 giving the next array size. """ exponent = math_ops.ceil( - math_ops.log(math_ops.cast(required_size, dtypes.float32)) - / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) - return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) + math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( + math_ops.cast(growth_factor, dtypes.float32))) + return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) def streaming_concat(values, @@ -2317,8 +3015,7 @@ def streaming_concat(values, if not 0 <= axis < ndim: raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) - fixed_shape = [dim.value for n, dim in enumerate(values_shape) - if n != axis] + fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis] if any(value is None for value in fixed_shape): raise ValueError('all dimensions of `values` other than the dimension to ' 'concatenate along must have statically known size') @@ -2427,60 +3124,81 @@ def aggregate_metric_map(names_to_tuples): return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) -def _remove_squeezable_dimensions(predictions, labels, weights): - """Squeeze last dim if needed. +def count(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the number of examples, or sum of `weights`. - Squeezes `predictions` and `labels` if their rank differs by 1. - Squeezes `weights` if its rank is 1 more than the new rank of `predictions` + When evaluating some metric (e.g. mean) on one or more subsets of the data, + this auxiliary metric is useful for keeping track of how many examples there + are in each subset. - This will use static shape if available. Otherwise, it will add graph - operations, which could result in a performance hit. + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. Args: - predictions: Predicted values, a `Tensor` of arbitrary dimensions. - labels: Label values, a `Tensor` whose dimensions match `predictions`. - weights: Optional weight `Tensor`. It will be squeezed if its rank is 1 - more than the new rank of `predictions` + values: A `Tensor` of arbitrary dimensions. Only it's shape is used. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions + must be either `1`, or the same as the corresponding `labels` + dimension). + metrics_collections: An optional list of collections that the metric + value variable should be added to. + updates_collections: An optional list of collections that the metric update + ops should be added to. + name: An optional variable_scope name. Returns: - Tuple of `predictions`, `labels` and `weights`, possibly with the last - dimension squeezed. + count: A `Tensor` representing the current value of the metric. + update_op: An operation that accumulates the metric from a batch of data. + + Raises: + ValueError: If `weights` is not `None` and its shape doesn't match `values`, + or if either `metrics_collections` or `updates_collections` are not a list + or tuple. """ - labels, predictions = confusion_matrix.remove_squeezable_dimensions( - labels, predictions) - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - if weights is not None: - weights = ops.convert_to_tensor(weights) - predictions_shape = predictions.get_shape() - predictions_rank = predictions_shape.ndims - weights_shape = weights.get_shape() - weights_rank = weights_shape.ndims - - if (predictions_rank is not None) and (weights_rank is not None): - # Use static rank. - if weights_rank - predictions_rank == 1: - weights = array_ops.squeeze(weights, [-1]) - elif (weights_rank is None) or ( - weights_shape.dims[-1].is_compatible_with(1)): - # Use dynamic rank - weights = control_flow_ops.cond( - math_ops.equal(array_ops.rank(weights), - math_ops.add(array_ops.rank(predictions), 1)), - lambda: array_ops.squeeze(weights, [-1]), - lambda: weights) - return predictions, labels, weights + with variable_scope.variable_scope(name, 'count', (values, weights)): + count_ = _create_local('count', shape=[]) + + if weights is None: + num_values = math_ops.to_float(array_ops.size(values)) + else: + _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=values, + labels=None, + weights=weights) + weights = weights_broadcast_ops.broadcast_weights( + math_ops.to_float(weights), values) + num_values = math_ops.reduce_sum(weights) + + with ops.control_dependencies([values]): + update_op = state_ops.assign_add(count_, num_values) + + if metrics_collections: + ops.add_to_collections(metrics_collections, count_) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return count_, update_op __all__ = [ 'aggregate_metric_map', 'aggregate_metrics', + 'count', 'sparse_recall_at_top_k', 'streaming_accuracy', 'streaming_auc', 'streaming_curve_points', + 'streaming_false_negative_rate', + 'streaming_false_negative_rate_at_thresholds', 'streaming_false_negatives', 'streaming_false_negatives_at_thresholds', + 'streaming_false_positive_rate', + 'streaming_false_positive_rate_at_thresholds', 'streaming_false_positives', 'streaming_false_positives_at_thresholds', 'streaming_mean', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 9b959b43a9db8baac5b37524e81bfbb11d6ad868..6a8284786f592b2fe840e3c68099fecc93dc91c6 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) precision, update_op = metrics.streaming_precision(predictions, labels) with self.test_session() as sess: @@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) recall, update_op = metrics.streaming_recall(predictions, labels) with self.test_session() as sess: @@ -1355,6 +1355,262 @@ class StreamingRecallTest(test.TestCase): self.assertEqual(0, recall.eval()) +class StreamingFPRTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_local_variables(self, ( + 'false_positive_rate/false_positives/count:0', + 'false_positive_rate/true_negatives/count:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. + initial_fpr = fpr.eval() + for _ in range(10): + self.assertEqual(initial_fpr, fpr.eval()) + + def testAllCorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(np_inputs) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fpr.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.5, update_op.eval()) + self.assertAlmostEqual(0.5, fpr.eval()) + + def testWeighted1d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[2], [5]]) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fp = 2.0 + 5.0 + weighted_f = (2.0 + 2.0) + (5.0 + 5.0) + expected_fpr = weighted_fp / weighted_f + self.assertAlmostEqual(expected_fpr, update_op.eval()) + self.assertAlmostEqual(expected_fpr, fpr.eval()) + + def testWeighted2d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fp = 1.0 + 3.0 + weighted_f = (1.0 + 4.0) + (2.0 + 3.0) + expected_fpr = weighted_fp / weighted_f + self.assertAlmostEqual(expected_fpr, update_op.eval()) + self.assertAlmostEqual(expected_fpr, fpr.eval()) + + def testAllIncorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(1 - np_inputs) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(1, fpr.eval()) + + def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self): + predictions = array_ops.ones((1, 4)) + labels = array_ops.ones((1, 4)) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fpr.eval()) + + +class StreamingFNRTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_local_variables(self, ( + 'false_negative_rate/false_negatives/count:0', + 'false_negative_rate/true_positives/count:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. + initial_fnr = fnr.eval() + for _ in range(10): + self.assertEqual(initial_fnr, fnr.eval()) + + def testAllCorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(np_inputs) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fnr.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.5, update_op.eval()) + self.assertAlmostEqual(0.5, fnr.eval()) + + def testWeighted1d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[2], [5]]) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fn = 2.0 + 5.0 + weighted_t = (2.0 + 2.0) + (5.0 + 5.0) + expected_fnr = weighted_fn / weighted_t + self.assertAlmostEqual(expected_fnr, update_op.eval()) + self.assertAlmostEqual(expected_fnr, fnr.eval()) + + def testWeighted2d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fn = 2.0 + 4.0 + weighted_t = (2.0 + 3.0) + (1.0 + 4.0) + expected_fnr = weighted_fn / weighted_t + self.assertAlmostEqual(expected_fnr, update_op.eval()) + self.assertAlmostEqual(expected_fnr, fnr.eval()) + + def testAllIncorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(1 - np_inputs) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(1, fnr.eval()) + + def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self): + predictions = array_ops.zeros((1, 4)) + labels = array_ops.zeros((1, 4)) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fnr.eval()) + + class StreamingCurvePointsTest(test.TestCase): def setUp(self): @@ -1481,7 +1737,7 @@ class StreamingAUCTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_auc(predictions, labels) with self.test_session() as sess: @@ -1714,6 +1970,170 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(expected_auc, auc.eval(), 2) +class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def _testResultsEqual(self, expected_dict, gotten_result): + """Tests that 2 results (dicts) represent the same data. + + Args: + expected_dict: A dictionary with keys that are the names of properties + of PrecisionRecallData and whose values are lists of floats. + gotten_result: A PrecisionRecallData object. + """ + gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} + self.assertItemsEqual( + list(expected_dict.keys()), list(gotten_dict.keys())) + + for key, expected_values in expected_dict.items(): + self.assertAllClose(expected_values, gotten_dict[key]) + + def _testCase(self, predictions, labels, expected_result, weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type float32. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + weights: Optional weights tensor. + """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant( + predictions, dtype=dtypes_lib.float32) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) + gotten_result, update_op = ( + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=predictions_tensor, + labels=labels_tensor, + num_thresholds=3, + weights=weights_tensor)) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result) + + def testVars(self): + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32), + labels=constant_op.constant([True], dtype=dtypes_lib.bool)) + _assert_local_variables( + self, + ( + 'precision_recall_at_equal_thresholds/variables/tp_buckets:0', + 'precision_recall_at_equal_thresholds/variables/fp_buckets:0' + )) + + def testVarsWithName(self): + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32), + labels=constant_op.constant([True], dtype=dtypes_lib.bool), + name='foo') + _assert_local_variables( + self, ('foo/variables/tp_buckets:0', 'foo/variables/fp_buckets:0')) + + def testValuesAreIdempotent(self): + predictions = constant_op.constant( + np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32) + labels = constant_op.constant( + np.random.uniform(size=(10, 3)) > 0.5, dtype=dtypes_lib.bool) + + result, update_op = ( + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=predictions, labels=labels)) + + with self.test_session() as sess: + # Run several updates. + sess.run(variables.local_variables_initializer()) + for _ in range(3): + sess.run(update_op) + + # Then verify idempotency. + initial_result = {k: value.eval().tolist() for k, value in + result._asdict().items()} + for _ in range(3): + self._testResultsEqual(initial_result, result) + + def testAllTruePositives(self): + self._testCase([[1]], [[True]], { + 'tp': [1, 1, 1], + 'fp': [0, 0, 0], + 'tn': [0, 0, 0], + 'fn': [0, 0, 0], + 'precision': [1.0, 1.0, 1.0], + 'recall': [1.0, 1.0, 1.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testAllTrueNegatives(self): + self._testCase([[0]], [[False]], { + 'tp': [0, 0, 0], + 'fp': [1, 0, 0], + 'tn': [0, 1, 1], + 'fn': [0, 0, 0], + 'precision': [0.0, 0.0, 0.0], + 'recall': [0.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testAllFalsePositives(self): + self._testCase([[1]], [[False]], { + 'tp': [0, 0, 0], + 'fp': [1, 1, 1], + 'tn': [0, 0, 0], + 'fn': [0, 0, 0], + 'precision': [0.0, 0.0, 0.0], + 'recall': [0.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testAllFalseNegatives(self): + self._testCase([[0]], [[True]], { + 'tp': [1, 0, 0], + 'fp': [0, 0, 0], + 'tn': [0, 0, 0], + 'fn': [0, 1, 1], + 'precision': [1.0, 0.0, 0.0], + 'recall': [1.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testManyValues(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], + { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testManyValuesWithWeights(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], + { + 'tp': [1.5, 1.5, 0.0], + 'fp': [2.5, 0.0, 0.0], + 'tn': [0.0, 2.5, 2.5], + 'fn': [0.0, 0.0, 1.5], + 'precision': [0.375, 1.0, 0.0], + 'recall': [1.0, 1.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) + + class StreamingSpecificityAtSensitivityTest(test.TestCase): def setUp(self): @@ -1753,7 +2173,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) @@ -1984,58 +2404,571 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): def testMetricsCollection(self): my_collection_name = '__metrics__' - prec, _ = metrics.streaming_precision_at_thresholds( - predictions=array_ops.ones((10, 1)), - labels=array_ops.ones((10, 1)), - thresholds=[0, 0.5, 1.0], - metrics_collections=[my_collection_name]) - rec, _ = metrics.streaming_recall_at_thresholds( + prec, _ = metrics.streaming_precision_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + rec, _ = metrics.streaming_recall_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, precision_op = metrics.streaming_precision_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + _, recall_op = metrics.streaming_recall_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + self.assertListEqual( + ops.get_collection(my_collection_name), [precision_op, recall_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + thresholds = [0, 0.5, 1.0] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run([prec_op, rec_op]) + + # Then verify idempotency. + initial_prec = prec.eval() + initial_rec = rec.eval() + for _ in range(10): + self.assertAllClose(initial_prec, prec.eval()) + self.assertAllClose(initial_rec, rec.eval()) + + # TODO(nsilberman): fix tests (passing but incorrect). + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertEqual(1, prec.eval()) + self.assertEqual(1, rec.eval()) + + def testSomeCorrect(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0.5, prec.eval()) + self.assertAlmostEqual(0.5, rec.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0, prec.eval()) + self.assertAlmostEqual(0, rec.eval()) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds, weights=weights) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds, weights=weights) + + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(1.0, prec_low.eval(), places=5) + self.assertAlmostEqual(0.0, prec_high.eval(), places=5) + self.assertAlmostEqual(1.0, rec_low.eval(), places=5) + self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds, weights=weights) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds, weights=weights) + + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(1.0, prec_low.eval(), places=5) + self.assertAlmostEqual(0.0, prec_high.eval(), places=5) + self.assertAlmostEqual(1.0, rec_low.eval(), places=5) + self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + + def testExtremeThresholds(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) + thresholds = [-1.0, 2.0] # lower/higher than any values + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0.75, prec_low.eval()) + self.assertAlmostEqual(0.0, prec_high.eval()) + self.assertAlmostEqual(1.0, rec_low.eval()) + self.assertAlmostEqual(0.0, rec_high.eval()) + + def testZeroLabelsPredictions(self): + with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) + labels = array_ops.zeros([4]) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0, prec.eval(), 6) + self.assertAlmostEqual(0, rec.eval(), 6) + + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) + + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [0.3] + + tp = 0 + fp = 0 + fn = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] > thresholds[0]: + if labels[i] == 1: + tp += 1 + else: + fp += 1 + else: + if labels[i] == 1: + fn += 1 + else: + tn += 1 + epsilon = 1e-7 + expected_prec = tp / (epsilon + tp + fp) + expected_rec = tp / (epsilon + tp + fn) + + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) + + with self.test_session() as sess: + # Reshape the data so its easy to queue up: + predictions_batches = predictions.reshape((batch_size, num_batches)) + labels_batches = labels.reshape((batch_size, num_batches)) + + # Enqueue the data: + predictions_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + labels_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + + for i in range(int(num_batches)): + tf_prediction = constant_op.constant(predictions_batches[:, i]) + tf_label = constant_op.constant(labels_batches[:, i]) + sess.run([ + predictions_queue.enqueue(tf_prediction), + labels_queue.enqueue(tf_label) + ]) + + tf_predictions = predictions_queue.dequeue() + tf_labels = labels_queue.dequeue() + + prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions, + tf_labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions, + tf_labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + for _ in range(int(num_samples / batch_size)): + sess.run([prec_op, rec_op]) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_prec, prec.eval(), 2) + self.assertAlmostEqual(expected_rec, rec.eval(), 2) + + +class StreamingFPRThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0]) + _assert_local_variables(self, ( + 'false_positive_rate_at_thresholds/false_positives:0', + 'false_positive_rate_at_thresholds/true_negatives:0',)) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + fpr, _ = metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [fpr]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + self.assertListEqual( + ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + thresholds = [0, 0.5, 1.0] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(fpr_op) + + # Then verify idempotency. + initial_fpr = fpr.eval() + for _ in range(10): + self.assertAllClose(initial_fpr, fpr.eval()) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertEqual(0, fpr.eval()) + + def testSomeCorrect(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.5, fpr.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(1, fpr.eval()) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds, weights=weights) + + fpr_low = fpr[0] + fpr_high = fpr[1] + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds, weights=weights) + + fpr_low = fpr[0] + fpr_high = fpr[1] + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testExtremeThresholds(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) + thresholds = [-1.0, 2.0] # lower/higher than any values + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + fpr_low = fpr[0] + fpr_high = fpr[1] + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(1.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testZeroLabelsPredictions(self): + with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) + labels = array_ops.zeros([4]) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0, fpr.eval(), 6) + + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) + + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [0.3] + + fp = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] > thresholds[0]: + if labels[i] == 0: + fp += 1 + else: + if labels[i] == 0: + tn += 1 + epsilon = 1e-7 + expected_fpr = fp / (epsilon + fp + tn) + + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) + + with self.test_session() as sess: + # Reshape the data so its easy to queue up: + predictions_batches = predictions.reshape((batch_size, num_batches)) + labels_batches = labels.reshape((batch_size, num_batches)) + + # Enqueue the data: + predictions_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + labels_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + + for i in range(int(num_batches)): + tf_prediction = constant_op.constant(predictions_batches[:, i]) + tf_label = constant_op.constant(labels_batches[:, i]) + sess.run([ + predictions_queue.enqueue(tf_prediction), + labels_queue.enqueue(tf_label) + ]) + + tf_predictions = predictions_queue.dequeue() + tf_labels = labels_queue.dequeue() + + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + tf_predictions, tf_labels, thresholds) + + sess.run(variables.local_variables_initializer()) + for _ in range(int(num_samples / batch_size)): + sess.run(fpr_op) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_fpr, fpr.eval(), 2) + + +class StreamingFNRThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_negative_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0]) + _assert_local_variables(self, ( + 'false_negative_rate_at_thresholds/false_negatives:0', + 'false_negative_rate_at_thresholds/true_positives:0',)) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + fnr, _ = metrics.streaming_false_negative_rate_at_thresholds( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], metrics_collections=[my_collection_name]) - self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec]) + self.assertListEqual(ops.get_collection(my_collection_name), [fnr]) def testUpdatesCollection(self): my_collection_name = '__updates__' - _, precision_op = metrics.streaming_precision_at_thresholds( - predictions=array_ops.ones((10, 1)), - labels=array_ops.ones((10, 1)), - thresholds=[0, 0.5, 1.0], - updates_collections=[my_collection_name]) - _, recall_op = metrics.streaming_recall_at_thresholds( + _, update_op = metrics.streaming_false_negative_rate_at_thresholds( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], updates_collections=[my_collection_name]) self.assertListEqual( - ops.get_collection(my_collection_name), [precision_op, recall_op]) + ops.get_collection(my_collection_name), [update_op]) def testValueTensorIsIdempotent(self): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) - # Run several updates, then verify idempotency. - sess.run([prec_op, rec_op]) - initial_prec = prec.eval() - initial_rec = rec.eval() + # Run several updates. for _ in range(10): - sess.run([prec_op, rec_op]) - self.assertAllClose(initial_prec, prec.eval()) - self.assertAllClose(initial_rec, rec.eval()) + sess.run(fnr_op) + + # Then verify idempotency. + initial_fnr = fnr.eval() + for _ in range(10): + self.assertAllClose(initial_fnr, fnr.eval()) - # TODO(nsilberman): fix tests (passing but incorrect). def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -2043,17 +2976,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertEqual(1, prec.eval()) - self.assertEqual(1, rec.eval()) + self.assertEqual(0, fnr.eval()) def testSomeCorrect(self): with self.test_session() as sess: @@ -2061,17 +2990,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0.5, prec.eval()) - self.assertAlmostEqual(0.5, rec.eval()) + self.assertAlmostEqual(0.5, fnr.eval()) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -2080,17 +3005,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0, prec.eval()) - self.assertAlmostEqual(0, rec.eval()) + self.assertAlmostEqual(1, fnr.eval()) def testWeights1d(self): with self.test_session() as sess: @@ -2100,27 +3021,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): weights = constant_op.constant( [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) thresholds = [0.5, 1.1] - prec, prec_op = metrics.streaming_precision_at_thresholds( - predictions, labels, thresholds, weights=weights) - rec, rec_op = metrics.streaming_recall_at_thresholds( + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(1.0, prec_low.eval(), places=5) - self.assertAlmostEqual(0.0, prec_high.eval(), places=5) - self.assertAlmostEqual(1.0, rec_low.eval(), places=5) - self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + self.assertAlmostEqual(0.0, fnr_low.eval(), places=5) + self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testWeights2d(self): with self.test_session() as sess: @@ -2130,27 +3041,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): weights = constant_op.constant( [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) thresholds = [0.5, 1.1] - prec, prec_op = metrics.streaming_precision_at_thresholds( - predictions, labels, thresholds, weights=weights) - rec, rec_op = metrics.streaming_recall_at_thresholds( + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(1.0, prec_low.eval(), places=5) - self.assertAlmostEqual(0.0, prec_high.eval(), places=5) - self.assertAlmostEqual(1.0, rec_low.eval(), places=5) - self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + self.assertAlmostEqual(0.0, fnr_low.eval(), places=5) + self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testExtremeThresholds(self): with self.test_session() as sess: @@ -2158,41 +3059,30 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) thresholds = [-1.0, 2.0] # lower/higher than any values - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0.75, prec_low.eval()) - self.assertAlmostEqual(0.0, prec_high.eval()) - self.assertAlmostEqual(1.0, rec_low.eval()) - self.assertAlmostEqual(0.0, rec_high.eval()) + self.assertAlmostEqual(0.0, fnr_low.eval()) + self.assertAlmostEqual(1.0, fnr_high.eval()) def testZeroLabelsPredictions(self): with self.test_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0, prec.eval(), 6) - self.assertAlmostEqual(0, rec.eval(), 6) + self.assertAlmostEqual(0, fnr.eval(), 6) def testWithMultipleUpdates(self): num_samples = 1000 @@ -2207,24 +3097,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions[predictions < 0] = 0 thresholds = [0.3] - tp = 0 - fp = 0 fn = 0 - tn = 0 + tp = 0 for i in range(num_samples): if predictions[i] > thresholds[0]: if labels[i] == 1: tp += 1 - else: - fp += 1 else: if labels[i] == 1: fn += 1 - else: - tn += 1 epsilon = 1e-7 - expected_prec = tp / (epsilon + tp + fp) - expected_rec = tp / (epsilon + tp + fn) + expected_fnr = fn / (epsilon + fn + tp) labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) @@ -2251,21 +3134,16 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): tf_predictions = predictions_queue.dequeue() tf_labels = labels_queue.dequeue() - prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions, - tf_labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions, - tf_labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + tf_predictions, tf_labels, thresholds) sess.run(variables.local_variables_initializer()) for _ in range(int(num_samples / batch_size)): - sess.run([prec_op, rec_op]) + sess.run(fnr_op) # Since this is only approximate, we can't expect a 6 digits match. # Although with higher number of samples/thresholds we should see the # accuracy improving - self.assertAlmostEqual(expected_prec, prec.eval(), 2) - self.assertAlmostEqual(expected_rec, rec.eval(), 2) + self.assertAlmostEqual(expected_fnr, fnr.eval(), 2) # TODO(ptucker): Remove when we remove `streaming_recall_at_k`. @@ -4978,7 +5856,7 @@ class StreamingMeanIOUTest(test.TestCase): sess.run(variables.local_variables_initializer()) for _ in range(5): sess.run(update_op) - desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0, 0.]) + desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0]) self.assertAlmostEqual(desired_output, miou.eval()) def testUpdateOpEvalIsAccumulatedConfusionMatrix(self): @@ -5060,6 +5938,58 @@ class StreamingMeanIOUTest(test.TestCase): desired_miou = np.mean([2. / 4., 4. / 6.]) self.assertAlmostEqual(desired_miou, miou.eval()) + def testMissingClassInLabels(self): + labels = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant([ + [[0, 0, 2, 1, 1, 0], + [0, 1, 2, 2, 0, 1]], + [[0, 0, 2, 1, 1, 1], + [1, 1, 2, 0, 0, 0]]]) + num_classes = 3 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval()) + self.assertAlmostEqual( + 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)), + miou.eval()) + + def testMissingClassOverallSmall(self): + labels = constant_op.constant([0]) + predictions = constant_op.constant([0]) + num_classes = 2 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[1, 0], [0, 0]], update_op.eval()) + self.assertAlmostEqual(1, miou.eval()) + + def testMissingClassOverallLarge(self): + labels = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 1, 0, 0, 1, 1]], + [[0, 0, 0, 1, 1, 1], + [1, 1, 1, 0, 0, 0]]]) + num_classes = 3 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval()) + self.assertAlmostEqual( + 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval()) + class StreamingConcatTest(test.TestCase): @@ -5240,5 +6170,163 @@ class AggregateMetricMapTest(test.TestCase): self.assertEqual(4, names_to_values['m2'].eval()) +class CountTest(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def testVars(self): + metrics.count(array_ops.ones([4, 3])) + _assert_local_variables(self, ['count/count:0']) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.count( + array_ops.ones([4, 3]), metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.count( + array_ops.ones([4, 3]), updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testBasic(self): + with self.test_session() as sess: + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + result, update_op = metrics.count(values) + + sess.run(variables.local_variables_initializer()) + for _ in range(4): + sess.run(update_op) + self.assertAlmostEqual(8.0, sess.run(result), 5) + + def testUpdateOpsReturnsCurrentValue(self): + with self.test_session() as sess: + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + result, update_op = metrics.count(values) + + sess.run(variables.local_variables_initializer()) + + self.assertAlmostEqual(2.0, sess.run(update_op), 5) + self.assertAlmostEqual(4.0, sess.run(update_op), 5) + self.assertAlmostEqual(6.0, sess.run(update_op), 5) + self.assertAlmostEqual(8.0, sess.run(update_op), 5) + + self.assertAlmostEqual(8.0, sess.run(result), 5) + + def test1dWeightedValues(self): + with self.test_session() as sess: + # Create the queue that populates the values. + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) + _enqueue_vector(sess, weights_queue, [0.5]) + _enqueue_vector(sess, weights_queue, [0]) + _enqueue_vector(sess, weights_queue, [0]) + _enqueue_vector(sess, weights_queue, [1.2]) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for _ in range(4): + update_op.eval() + self.assertAlmostEqual(3.4, result.eval(), 5) + + def test1dWeightedValues_placeholders(self): + with self.test_session() as sess: + # Create the queue that populates the values. + feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) + values = array_ops.placeholder(dtype=dtypes_lib.float32) + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1,)) + _enqueue_vector(sess, weights_queue, 0.5, shape=(1,)) + _enqueue_vector(sess, weights_queue, 0, shape=(1,)) + _enqueue_vector(sess, weights_queue, 0, shape=(1,)) + _enqueue_vector(sess, weights_queue, 1.2, shape=(1,)) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for i in range(4): + update_op.eval(feed_dict={values: feed_values[i]}) + self.assertAlmostEqual(3.4, result.eval(), 5) + + def test2dWeightedValues(self): + with self.test_session() as sess: + # Create the queue that populates the values. + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, weights_queue, [1.1, 1]) + _enqueue_vector(sess, weights_queue, [1, 0]) + _enqueue_vector(sess, weights_queue, [0, 1]) + _enqueue_vector(sess, weights_queue, [0, 0]) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for _ in range(4): + update_op.eval() + self.assertAlmostEqual(4.1, result.eval(), 5) + + def test2dWeightedValues_placeholders(self): + with self.test_session() as sess: + # Create the queue that populates the values. + feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) + values = array_ops.placeholder(dtype=dtypes_lib.float32) + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(2,)) + _enqueue_vector(sess, weights_queue, [1.1, 1], shape=(2,)) + _enqueue_vector(sess, weights_queue, [1, 0], shape=(2,)) + _enqueue_vector(sess, weights_queue, [0, 1], shape=(2,)) + _enqueue_vector(sess, weights_queue, [0, 0], shape=(2,)) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for i in range(4): + update_op.eval(feed_dict={values: feed_values[i]}) + self.assertAlmostEqual(4.1, result.eval(), 5) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py index b94f7b0a353c4c3c698a927d8718bb5b490872cb..9ed16a6f078a506b60fd14f4356ff65a0a692203 100644 --- a/tensorflow/contrib/mpi_collectives/__init__.py +++ b/tensorflow/contrib/mpi_collectives/__init__.py @@ -194,7 +194,7 @@ class DistributedOptimizer(tf.train.Optimizer): See Optimizer.compute_gradients() for more info. - In DistributedOptimizer, compute_gradients() is overriden to also + In DistributedOptimizer, compute_gradients() is overridden to also allreduce the gradients before returning them. """ gradients = (super(DistributedOptimizer, self) diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index d6508362b8bf01468a43b26d6a0d0c9807b5967e..3aa3215a5fd51116b8443b1123c6ce5ea1f573c0 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -71,10 +71,14 @@ tf_kernel_library( "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", + "kernels/nccl_rewrite.cc", ], deps = [ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:proto_text", "@nccl_archive//:nccl", ], alwayslink = 1, diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc index 4eb52492dbcc386941029709631314634c1c9be1..266d4f6f0de0274dca2bfc9022503f09b0ca7d42 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc @@ -15,8 +15,6 @@ limitations under the License. #if GOOGLE_CUDA -#include -#include #include #include "src/nccl.h" @@ -24,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace { // Base class for all communicator ops that use nccl. // @@ -134,7 +133,7 @@ class NcclReduceSendKernel : public NcclReduceOpBase { compute_stream, &c->input(0), std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceSend").Device(DEVICE_GPU), NcclReduceSendKernel); // To execute a single reduce, this kernel is called once for one devices, and @@ -166,7 +165,7 @@ class NcclReduceRecvKernel : public NcclReduceOpBase { private: ncclRedOp_t reduction_op_; }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceRecv").Device(DEVICE_GPU), NcclReduceRecvKernel); // To execute a single broadcast, this kernel is called once for one device, and @@ -191,7 +190,7 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase { std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclBroadcastSend").Device(DEVICE_GPU), NcclBroadcastSendKernel); // To execute a single broadcast, this kernel is called once for all but one of @@ -206,7 +205,7 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { const Tensor& shape_t = c->input(0); TensorShape shape; OP_REQUIRES_OK_ASYNC( - c, TensorShapeUtils::MakeShape(shape_t.vec(), &shape), done); + c, TensorShapeUtils::MakeShape(shape_t.vec(), &shape), done); Tensor* out_t; OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done); @@ -224,9 +223,24 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { } }; REGISTER_KERNEL_BUILDER( - Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), + Name("_NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), NcclBroadcastRecvKernel); +// Define stub kernels for the ops that get replaced post placement. +class NcclStubKernel : public AsyncOpKernel { + public: + explicit NcclStubKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {} + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + c->SetStatus(errors::Unimplemented( + "This op should be replaced during graph optimization.")); + done(); + } +}; +REGISTER_KERNEL_BUILDER(Name("NcclBroadcast").Device(DEVICE_GPU), + NcclStubKernel); +REGISTER_KERNEL_BUILDER(Name("NcclReduce").Device(DEVICE_GPU), NcclStubKernel); + +} // namespace } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4de46a93fab1dfe93b47f2789cc533bc447e43a --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include +#include + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace { + +// Replaces NcclReduce node with _NcclReduceRecv reusing one input of same +// device, adds one _NcclReduceSend for each other input. +Status ReplaceReduce(Graph* graph, Node* node) { + string reduction; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "reduction", &reduction)); + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int num_devices = node->num_inputs(); + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("reduction", reduction) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + std::vector control_inputs; + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + control_inputs.push_back(edge->src()); + } + } + std::vector out_nodes; + for (const auto& edge : node->out_edges()) { + out_nodes.emplace_back(edge->dst(), edge->dst_input()); + } + int recv_dev = node->assigned_device_name_index(); + NodeBuilder recv_builder = + make_builder("_NcclReduceRecv", "Recv").ControlInputs(control_inputs); + bool recv_input_set = false; + int send_counter = 0; + for (const auto& edge : node->in_edges()) { + Node* src_node = edge->src(); + if (edge->IsControlEdge()) { + continue; + } + int send_dev = src_node->assigned_device_name_index(); + if (!recv_input_set && send_dev == recv_dev) { + recv_builder.Input(src_node); + recv_input_set = true; + continue; + } + auto send_builder = make_builder("_NcclReduceSend", + strings::StrCat("Send_", ++send_counter)) + .Input(src_node) + .ControlInputs(control_inputs); + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + // Send nodes don't have any outputs and therefore have no data dependencies + // to the outputs of the graph. We add a control dependency to the receive + // node so that those 'dangling' nodes are run. + // TODO(b/67027412): Avoid these cross-device control edges. + for (const auto& out_node : out_nodes) { + graph->AddControlEdge(send_node, out_node.node); + } + } + if (!recv_input_set) { + return errors::InvalidArgument( + "No input tensor uses the same device as the NcclReduce op"); + } + Node* recv_node = nullptr; + TF_RETURN_IF_ERROR(recv_builder.Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + graph->RemoveNode(node); + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(recv_node, out_node.node); + } else { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + return Status::OK(); +} + +TensorProto TensorFromShape(const TensorShapeProto& shape) { + TensorProto result; + result.set_dtype(DT_INT32); + for (const auto& dim : shape.dim()) { + result.add_int_val(dim.size()); + } + result.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size()); + return result; +} + +// Replaces NcclBroadcast node with _NcclBroadcastSend, connects the input to +// all outputs of same device, adds one _NcclBroadcastRecv for each other output +// device. +Status ReplaceBroadcast(Graph* graph, Node* node) { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int send_dev = node->assigned_device_name_index(); + int num_devices = 0; // Number of distinct devices, incremented below. + std::vector recv_index_map; // Map device name index to stable index. + + // Map device name index to nodes that take the broadcast as input. + std::vector> out_nodes_map; + for (const auto& edge : node->out_edges()) { + int dst_dev = edge->IsControlEdge() + ? send_dev + : edge->dst()->assigned_device_name_index(); + if (out_nodes_map.size() <= dst_dev) { + out_nodes_map.resize(dst_dev + 1); + recv_index_map.resize(dst_dev + 1); + } + auto it = out_nodes_map.begin() + dst_dev; + if (it->empty()) { + recv_index_map[dst_dev] = num_devices; + ++num_devices; + } + it->emplace_front(NodeBuilder::NodeOut(edge->dst(), edge->dst_input())); + } + + if (num_devices <= 1) { + // Only one participating device, skip NCCL op. + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(node->input_edge(0, &in_edge)); + Node* in_node = in_edge->src(); + int in_index = in_edge->src_output(); + graph->RemoveNode(node); + for (const auto& out_nodes : out_nodes_map) { + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(in_node, out_node.node); + } else { + graph->AddEdge(in_node, in_index, out_node.node, out_node.index); + } + } + } + return Status::OK(); + } + + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + + // Create broadcast send node and replace the original broadcast node. + NodeBuilder::NodeOut in_node; + NodeBuilder send_builder = make_builder("_NcclBroadcastSend", "Send"); + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + send_builder.ControlInput(edge->src()); + } else { + in_node = NodeBuilder::NodeOut(edge->src(), edge->src_output()); + send_builder.Input(in_node); + } + } + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + + TensorShapeProto shape_proto; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "shape", &shape_proto)); + + // Delete the original node before reconnecting to outputs. + graph->RemoveNode(node); + + // Connect all outputs on the device of broadcast send. + for (const auto& out_node : out_nodes_map[send_dev]) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(send_node, out_node.node); + } else { + graph->AddEdge(in_node.node, in_node.index, out_node.node, + out_node.index); + // Add control edge so send node is run. + graph->AddControlEdge(send_node, out_node.node); + } + } + out_nodes_map[send_dev].clear(); + + TensorProto tensor_proto = TensorFromShape(shape_proto); + bool is_fully_defined = TensorShape(shape_proto).IsFullyDefined(); + string shape_name = strings::StrCat(in_node.node->name(), "/Shape"); + Node* shape_node = nullptr; + if (!is_fully_defined) { + NodeBuilder shape_builder(shape_name, "Shape"); + shape_builder.Input(in_node).Attr("out_type", DT_INT32).Attr("T", dtype); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(send_dev); + } + + // For all other devices, create a broadcast receive and connect outputs. + for (int recv_dev = 0; recv_dev < out_nodes_map.size(); ++recv_dev) { + if (out_nodes_map[recv_dev].empty()) { + continue; + } + int recv_index = recv_index_map[recv_dev]; + if (is_fully_defined) { + // If the shape is fully defined, define one const node per device. + NodeBuilder shape_builder(strings::StrCat(shape_name, recv_index), + "Const"); + shape_builder.Attr("value", tensor_proto).Attr("dtype", DT_INT32); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(recv_dev); + } + Node* recv_node; + TF_RETURN_IF_ERROR( + make_builder("_NcclBroadcastRecv", strings::StrCat("Recv_", recv_index)) + .Input(shape_node) + .Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + for (const auto& out_node : out_nodes_map[recv_dev]) { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + + return Status::OK(); +} + +// Replaces occurrences of Nccl{Reduce, Broadcast}Input/Output with their +// _Nccl...Send/Recv counterparts and removes data dependencies between them. +class NcclReplacePass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override { + if (options.graph == nullptr) { + return Status::OK(); + } + Graph* graph = options.graph->get(); + if (graph == nullptr) { + return errors::Internal( + "NCCL replacement should happen before partitioning and a " + "graph should be available."); + } + // Find reduction and broadcast ops and replace them with Send/Recv ops. + for (Node* node : graph->op_nodes()) { + StringPiece type = node->type_string(); + if (!type.starts_with("Nccl")) { + continue; + } + if (type == "NcclReduce") { + TF_RETURN_IF_ERROR(ReplaceReduce(graph, node)); + } + if (type == "NcclBroadcast") { + TF_RETURN_IF_ERROR(ReplaceBroadcast(graph, node)); + } + } + return Status::OK(); + } +}; +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, + NcclReplacePass); + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc index 532c79c24cc9596af580ee3faf463aecbc59bb07..8eb804c2e988f313ba1b340217cae20f1f5502c7 100644 --- a/tensorflow/contrib/nccl/ops/nccl_ops.cc +++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc @@ -45,7 +45,28 @@ num_devices: The number of devices participating in this reduction. shared_name: Identifier that shared between ops of the same reduction. )doc"); -REGISTER_OP("NcclReduceSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclReduceSend and _NcclReduceRecv during graph optimization stage. +REGISTER_OP("NcclReduce") + .Input("input: num_devices * T") + .Output("data: T") + .Attr("reduction: {'min', 'max', 'prod', 'sum'}") + .Attr("T: {float, float64, int32, int64}") + .Attr("num_devices: int") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Reduces `input` from `num_devices` using `reduction` to a single device. + +The graph should be constructed so that all inputs have a valid device +assignment, and the op itself is assigned one of these devices. + +input: The input to the reduction. +data: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. + )doc"); + +REGISTER_OP("_NcclReduceSend") .Input("input: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") .Attr("T: {float, float64, int32, int64}") @@ -54,19 +75,20 @@ REGISTER_OP("NcclReduceSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. +Replacement node for NcclReduce. +Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclReduceRecv") +REGISTER_OP("_NcclReduceRecv") .Input("input: T") .Output("data: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") @@ -76,21 +98,42 @@ REGISTER_OP("NcclReduceRecv") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( +Replacement node for NcclReduce. + Reduces 'input' from this op and the NcclReduceSend ops registered in the same `shared_name`. - The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. data: The reduced data received from this op and the NcclReduceSend op. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclBroadcastSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclBroadcastSend and _NcclBroadcastRecv during graph optimization stage. +REGISTER_OP("NcclBroadcast") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, float64, int32, int64}") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Sends `input` to all devices that are connected to the output. + +The graph should be constructed so that all ops connected to the output have a +valid device assignment, and the op itself is assigned one of these devices. + +input: The input to the broadcast. +output: The same as input. +shape: The shape of the input tensor. + )doc"); + +REGISTER_OP("_NcclBroadcastSend") .Input("input: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -98,19 +141,21 @@ REGISTER_OP("NcclBroadcastSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends `input` to the _NcclBroadcastRecv ops registered in the same +`shared_name`. +The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the broadcast +input: The input to the broadcast. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same broadcast. )doc"); -REGISTER_OP("NcclBroadcastRecv") - .Input("shape: int64") +REGISTER_OP("_NcclBroadcastRecv") + .Input("shape: int32") .Output("output: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -123,11 +168,12 @@ REGISTER_OP("NcclBroadcastRecv") return Status::OK(); }) .Doc(R"doc( -Sends data of shape `shape` from the NcclBroadcastSend op registered in the -same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends data of shape `shape` from the _NcclBroadcastSend op registered in the +same `shared_name`. +The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. shape: The shape of the output. diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 906d9f948acf212dce1dbbbf9ec7c60c30f389b1..8dc038b9ac992de7db8b762e3697c6693099e192 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -23,9 +23,7 @@ from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.util import loader from tensorflow.python.eager import context from tensorflow.python.framework import device -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.platform import resource_loader _nccl_ops_so = loader.load_op_library( @@ -64,13 +62,13 @@ def _all_sum_grad(op, grad): LookupError: If `reduction` is not `sum`. """ if op.get_attr('reduction') != 'sum': - raise LookupError('No gradient defined for NcclAllReduce except all_sum.') + raise LookupError('No gradient defined for NcclAllReduce except sum.') - _check_device_assignment(grad) + _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') shared_name = op.get_attr('shared_name') + '_grad' - with ops.device(grad.device): + with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( input=grad, reduction='sum', @@ -129,7 +127,7 @@ def all_max(tensors): return _apply_all_reduce('max', tensors) -def reduce_sum(tensors, dst_device): +def reduce_sum(tensors): """Returns a tensor with the reduce sum across `tensors`. The computation is done with a reduce operation, so only one tensor is @@ -138,54 +136,76 @@ def reduce_sum(tensors, dst_device): Args: tensors: The input tensors across which to sum; must be assigned to GPU devices. - dst_device: The device of the returned tensor. Returns: - A tensor containing the sum of the input tensors, with the device of the - tensor being `dst_device`. + A tensor containing the sum of the input tensors. + + Raises: + LookupError: If context is not currently using a GPU device. + """ + return _apply_reduce('sum', tensors) + + +@ops.RegisterGradient('NcclReduce') +def _reduce_sum_grad(op, grad): + """The gradients for input `Operation` of `reduce_sum`. + + Args: + op: The `sum send` `Operation` that we are differentiating. + grad: Gradient with respect to the output of the `reduce_sum` op. + + Returns: + The gradient with respect to the input of `reduce_sum` op. + + Raises: + LookupError: If the reduction attribute of op is not `sum`. """ - return _apply_reduce('sum', tensors, dst_device) + if op.get_attr('reduction') != 'sum': + raise LookupError('No gradient defined for NcclReduce except sum.') + _check_device(grad, expected=op.device) + with ops.device(op.device): + result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape) -def broadcast(src_tensor, dst_devices): - """Returns a list of tensors on `dst_devices`, each with value `tensor`. + return [result] * len(op.inputs) - The computation is done with a broadcast nccl operation, so if only some of - the returned tensors and src_tensor are evaluated then the computation will - hang. + +def broadcast(tensor): + """Returns a tensor that can be efficiently transferred to other devices. Args: - src_tensor: The tensor to send; must be assigned to a GPU device. - dst_devices: The GPU devices to receive the sent tensor. + tensor: The tensor to send; must be assigned to a GPU device. Returns: - An `Operation` to send the `src_tensor`, and a list of tensors, each with - the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`. + A tensor with the value of `src_tensor`, which can be used as input to + ops on other GPU devices. """ - if not dst_devices: - raise ValueError('Must pass >0 dst_devices to broadcast') _check_graph_mode() - _check_device_assignment(src_tensor) + _check_device(tensor) - shape = array_ops.shape(src_tensor, out_type=dtypes.int64) - num_devices = len(dst_devices) + 1 - shared_name = _get_shared_name() + with ops.device(tensor.device): + return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape) - with ops.device(src_tensor.device): - send = gen_nccl_ops.nccl_broadcast_send( - input=src_tensor, num_devices=num_devices, shared_name=shared_name) - - recvs = [] - for d in dst_devices: - with ops.device(d): - recvs.append( - gen_nccl_ops.nccl_broadcast_recv( - shape=shape, - T=src_tensor.dtype, - num_devices=num_devices, - shared_name=shared_name)) - return send, recvs +@ops.RegisterGradient('NcclBroadcast') +def _broadcast_grad(op, accumulated_grad): + """The gradients for input `Operation` of `broadcast`. + + Args: + op: The `broadcast send` `Operation` that we are differentiating. + accumulated_grad: Accumulated gradients with respect to the output of the + `broadcast` op. + + Returns: + Gradients with respect to the input of `broadcast`. + """ + # Grab inputs of accumulated_grad and replace accumulation with reduce_sum. + grads = [t for t in accumulated_grad.op.inputs] + for t in grads: + _check_device(t) + + with ops.device(op.device): + return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum') def _apply_all_reduce(reduction, tensors): @@ -198,7 +218,7 @@ def _apply_all_reduce(reduction, tensors): res = [] for t in tensors: - _check_device_assignment(t) + _check_device(t) with ops.device(t.device): res.append( gen_nccl_ops.nccl_all_reduce( @@ -210,40 +230,20 @@ def _apply_all_reduce(reduction, tensors): return res -def _apply_reduce(reduction, tensors, dst_device): +def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - if not dst_device: - raise ValueError('Must pass dst_device to reduce operations') _check_graph_mode() + for t in tensors: + _check_device(t) + result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction) try: - recv_index = next(i for i, t in enumerate(tensors) - if t.device == dst_device) + next(t for t in tensors if t.device == result.device) except StopIteration: - raise ValueError('One of the tensors must be assigned to dst_device') - shared_name = _get_shared_name() - - sends = [] - for t in tensors[:recv_index] + tensors[recv_index + 1:]: - _check_device_assignment(t) - with ops.device(t.device): - sends.append( - gen_nccl_ops.nccl_reduce_send( - input=t, - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name)) - - with ops.device(dst_device): - recv = gen_nccl_ops.nccl_reduce_recv( - input=tensors[recv_index], - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name) - - return recv, sends + raise ValueError('One input tensor must be assigned to current device') + return result _lock = threading.Lock() @@ -259,9 +259,11 @@ def _get_shared_name(): return 'c%s' % val -def _check_device_assignment(tensor): +def _check_device(tensor, expected=None): if not device.canonical_name(tensor.device): raise ValueError('Device assignment required for nccl collective ops') + if expected and expected != tensor.device: + raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) def _check_graph_mode(): diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index 96d67723a0ad197436a12924bd2b4ecb73eee4cb..0b13e3595e36b609468f459d9179f8e9f5c1e055 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -22,8 +22,10 @@ from functools import partial import numpy as np from tensorflow.contrib import nccl +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients from tensorflow.python.platform import test @@ -36,27 +38,30 @@ def _DeviceTensors(tensors, devices): def _NcclAllReduce(nccl_fun, tensors, devices): - return nccl_fun(_DeviceTensors(tensors, devices)), [] + return nccl_fun(_DeviceTensors(tensors, devices)) def _NcclReduce(nccl_fun, tensors, devices): - d_tensors = _DeviceTensors(tensors, devices) receiver = np.random.randint(0, len(devices)) - received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver]) - return [received_tensor], send_ops + with ops.device(devices[receiver]): + return [nccl_fun(_DeviceTensors(tensors, devices))] def _NcclBroadcast(tensors, devices): sender = np.random.randint(0, len(devices)) - d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0] - other_devices = devices[:sender] + devices[sender + 1:] - send_op, received_tensors = nccl.broadcast(d_tensor, other_devices) - return received_tensors, [send_op] + with ops.device(devices[sender]): + tensor = array_ops.identity(tensors[0]) + broadcast = nccl.broadcast(tensor) + return _DeviceTensors([broadcast] * len(devices), devices) class NcclTestCase(test.TestCase): - def _Test(self, nccl_reduce, numpy_fn): + def _Test(self, + nccl_reduce, + numpy_fn, + device_sets=(['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], + ['/device:GPU:1', '/device:GPU:0'])): """Tests that nccl_reduce does the same as reduction with numpy_fn. Args: @@ -65,6 +70,7 @@ class NcclTestCase(test.TestCase): reduction. numpy_fn: A function taking two tensors and returning the reduction of the two. + device_sets: Tuple of virtual devices to run test on. """ if not test.is_gpu_available(): return # Test requires access to a GPU @@ -74,26 +80,28 @@ class NcclTestCase(test.TestCase): # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: - for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], - ['/device:GPU:1', '/device:GPU:0']]: + for devices in device_sets: shape = (3, 4) random = (np.random.random_sample(shape) - .5) * 1024 - tensors = [random.astype(dtype)] * len(devices) + tensors = [] + for _ in devices: + tensors.append(random.astype(dtype)) np_ans = tensors[0] for t in tensors[1:]: np_ans = numpy_fn(np_ans, t) - reduce_tensors, reduce_ops = nccl_reduce(tensors, devices) + reduce_tensors = nccl_reduce(tensors, devices) self.assertNotEmpty(reduce_tensors) # Test shape inference. for r in reduce_tensors: self.assertEqual(shape, r.get_shape()) + result_tensors = [array_ops.identity(t) for t in reduce_tensors] + # Test execution and results. - nccl_results = sess.run(reduce_tensors + reduce_ops) - for r in nccl_results[:len(reduce_tensors)]: - self.assertAllClose(r, np_ans) + for t in sess.run(result_tensors): + self.assertAllClose(t, np_ans) def _TestGradient(self, nccl_reduce, numpy_fn): """Tests the gradient of nccl_reduce. @@ -106,14 +114,12 @@ class NcclTestCase(test.TestCase): reduction of the two. """ def _Gradient(tensors, devices): - reduce_tensors, _ = nccl_reduce(tensors, devices) - tensor_ops = [t.op for t in reduce_tensors] - d_tensors = _DeviceTensors(tensors, devices) - grad_tensors = [ - ops.get_gradient_function(op)(op, loss) - for op, loss in zip(tensor_ops, d_tensors) - ] - return grad_tensors, [] + inputs = [array_ops.placeholder(t.dtype, t.shape) for t in tensors] + reduce_tensors = nccl_reduce(inputs, devices) + losses = _DeviceTensors(tensors, [t.device for t in reduce_tensors]) + grads = gradients.gradients( + reduce_tensors, inputs, losses, colocate_gradients_with_ops=True) + return [g for g in grads if g is not None] self._Test(_Gradient, numpy_fn) @@ -142,27 +148,40 @@ class SingleReduceTest(NcclTestCase): def testSum(self): self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y) + def testSumGrad(self): + self._TestGradient(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x) + class BroadcastTest(NcclTestCase): def testBroadcast(self): self._Test(_NcclBroadcast, lambda x, y: x) + def testBroadcastSingleDevice(self): + # Broadcasts on a single device are removed completely during rewrite. + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:GPU:0'],)) + + def testBroadcastToCpuError(self): + # Broadcasts to CPU is not supported. + with self.assertRaisesRegexp( + errors.NotFoundError, + "No registered '_NcclBroadcastRecv' OpKernel for CPU devices"): + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:CPU:0'],)) + class CombinedTest(NcclTestCase): """Test all-reduce vs. single-reduce plus broadcast in one session.run.""" - def _combined(self, tensors, devices): - all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0] - single_reduce_tensors, single_reduce_ops = _NcclReduce( - nccl.reduce_sum, tensors, devices) - broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors, - devices) - all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors - return all_tensors, single_reduce_ops + broadcast_ops + def _Combined(self, tensors, devices): + all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices) + single_reduce_tensors = _NcclReduce(nccl.reduce_sum, tensors, devices) + broadcast_tensors = _NcclBroadcast(single_reduce_tensors, devices) + return all_reduce_tensors + broadcast_tensors def testCombined(self): - self._Test(self._combined, lambda x, y: x + y) + self._Test(self._Combined, lambda x, y: x + y) if __name__ == '__main__': diff --git a/tensorflow/contrib/nn/__init__.py b/tensorflow/contrib/nn/__init__.py index be0957f473ce6a6457267cdbca036363e6904e98..3bf795d19aad73ec37c0485fe1900a7d8ac43137 100644 --- a/tensorflow/contrib/nn/__init__.py +++ b/tensorflow/contrib/nn/__init__.py @@ -18,7 +18,9 @@ @@deprecated_flipped_softmax_cross_entropy_with_logits @@deprecated_flipped_sparse_softmax_cross_entropy_with_logits @@deprecated_flipped_sigmoid_cross_entropy_with_logits +@@nth_element @@rank_sampled_softmax_loss +@@scaled_softplus """ from __future__ import absolute_import @@ -30,6 +32,7 @@ from tensorflow.contrib.nn.python.ops.alpha_dropout import * from tensorflow.contrib.nn.python.ops.cross_entropy import * from tensorflow.contrib.nn.python.ops.sampling_ops import * from tensorflow.contrib.nn.python.ops.scaled_softplus import * +from tensorflow.python.ops.nn_ops import nth_element # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus.py b/tensorflow/contrib/nn/python/ops/scaled_softplus.py index 5fc11d8ec66fc6e32d5028a6b6181c784104db06..fcbfbc239ca5b8a1d4b17b403f99b7eb05db47b0 100644 --- a/tensorflow/contrib/nn/python/ops/scaled_softplus.py +++ b/tensorflow/contrib/nn/python/ops/scaled_softplus.py @@ -20,58 +20,96 @@ from __future__ import print_function from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -def scaled_softplus(x, alpha, name=None): - """Returns `alpha * ln(1 + exp(x / alpha))`, for scalar `alpha > 0`. +def _reduce_and_reshape_grad(g, t): + """Returns the gradient, sum-reduced and reshaped to `t`'s shape.""" + shape = array_ops.shape(t) + g_shape = array_ops.shape(g) + # pylint: disable=protected-access + bcast_dims, _ = gen_array_ops._broadcast_gradient_args(shape, g_shape) + # pylint: enable=protected-access + return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape) + + +def scaled_softplus(x, alpha, clip=None, name=None): + """Returns `y = alpha * ln(1 + exp(x / alpha))` or `min(y, clip)`. This can be seen as a softplus applied to the scaled input, with the output appropriately scaled. As `alpha` tends to 0, `scaled_softplus(x, alpha)` tends - to `relu(x)`. + to `relu(x)`. The clipping is optional. As alpha->0, scaled_softplus(x, alpha) + tends to relu(x), and scaled_softplus(x, alpha, clip=6) tends to relu6(x). Note: the gradient for this operation is defined to depend on the backprop inputs as well as the outputs of this operation. Args: x: A `Tensor` of inputs. - alpha: A scalar `Tensor`, indicating the amount of smoothness. The caller + alpha: A `Tensor`, indicating the amount of smoothness. The caller must ensure that `alpha > 0`. + clip: (optional) A `Tensor`, the upper bound to clip the values. name: A name for the scope of the operations (optional). Returns: - A tensor of same size and type as `x`. + A tensor of the size and type determined by broadcasting of the inputs. """ - with ops.name_scope(name, 'scaled_softplus', [x, alpha]): + clipping = clip is not None + with ops.name_scope(name, 'scaled_softplus', + [x, alpha] + ([clip] if clipping else [])): x = ops.convert_to_tensor(x, name='x') dtype = x.dtype alpha = ops.convert_to_tensor(alpha, dtype=dtype, name='alpha') - # Verify that alpha is a scalar. - alpha.get_shape().assert_has_rank(0) + # Compute the forward value. + y = alpha * nn.softplus(x / alpha) + if clipping: + clip = ops.convert_to_tensor(clip, dtype=dtype, name='clip') + y = math_ops.minimum(y, clip) def _grad(op, g): - """Backprop for scaled softplus.""" - y = op.outputs[0] - alpha = op.inputs[1] - # Prevent the expensive computations from happening before g is available. + """Backprop for scaled softplus, with optional clipping.""" + y, x, alpha = op.inputs[:3] + # Prevent the memory-expensive computations from happening before g is + # available. with ops.control_dependencies([g]): - y /= alpha + y = array_ops.identity(y) + clip_grad = [] + if clipping: + clip = op.inputs[3] + unclipped = math_ops.cast(y < clip, g.dtype) + clip_grad = [_reduce_and_reshape_grad(g * (1. - unclipped), clip)] + g *= unclipped + y /= alpha emy = math_ops.exp(-y) dy_dx = 1. - emy # The eps below avoids log(0). Note that t*log(t) -> 0 as t->0. eps = 1e-8 dy_dalpha = y * emy - dy_dx * math_ops.log(dy_dx + eps) - return g * dy_dx, math_ops.reduce_sum(g * dy_dalpha) + # Backprop to the actual inputs, but not to the output. + return [None, + _reduce_and_reshape_grad(g * dy_dx, x), + _reduce_and_reshape_grad(g * dy_dalpha, alpha)] + clip_grad - @function.Defun(dtype, dtype, - func_name='ScaledSoftplus_%s' % dtype.name, - shape_func=lambda op: [op.inputs[0].get_shape()], + if clipping: + @function.Defun(dtype, dtype, dtype, dtype, + func_name='ScaledSoftplusHelper_clip_%s' % dtype.name, + shape_func=lambda op: [op.inputs[0].shape], + python_grad_func=_grad) + def _forward_helper_clip(y, x, alpha, clip): + del x, alpha, clip # Unused. + return y + return _forward_helper_clip(y, x, alpha, clip) + # No clipping. + @function.Defun(dtype, dtype, dtype, + func_name='ScaledSoftplusHelper_%s' % dtype.name, + shape_func=lambda op: [op.inputs[0].shape], python_grad_func=_grad) - def _forward(x, alpha): - """Forward computation of scaled softplus.""" - return alpha * nn.softplus(x / alpha) - - return _forward(x, alpha) + def _forward_helper(y, x, alpha): + del x, alpha # Unused. + return y + return _forward_helper(y, x, alpha) diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py index 3a459330ceb774b033ff6622b4c90807c782f06f..b978343c6a79af856d833b0ab8002c256ce478e0 100644 --- a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py +++ b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py @@ -33,10 +33,11 @@ class ScaledSoftplusTest(test.TestCase): x = np.random.randn(3, 4).astype(np.float32) x64 = np.random.randn(3, 4).astype(np.float64) alpha = np.random.rand() + 0.01 - y = alpha * np.log(1. + np.exp(x / alpha)) + clip = np.float32(0.1) + y = np.minimum(alpha * np.log(1. + np.exp(x / alpha)), clip) y64 = alpha * np.log(1. + np.exp(x64 / alpha)) with self.test_session(use_gpu=True) as sess: - z = scaled_softplus(constant_op.constant(x), alpha) + z = scaled_softplus(constant_op.constant(x), alpha, clip) z64 = scaled_softplus(constant_op.constant(x64), alpha) z, z64 = sess.run([z, z64]) eps = 1e-6 @@ -47,18 +48,28 @@ class ScaledSoftplusTest(test.TestCase): np.random.seed(1) # Make it reproducible. x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) - alpha_np = np.float32(np.random.rand() + 0.01) + alpha_np = np.float32(np.random.rand(1, x_shape[1]) + 0.01) + clip_np = np.float32(np.random.rand(x_shape[0], 1) * 5.) with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np) alpha_tf = constant_op.constant(alpha_np) + clip_tf = constant_op.constant(clip_np) y_tf = scaled_softplus(x_tf, alpha_tf) + z_tf = scaled_softplus(x_tf, alpha_tf, clip_tf * 0.1) err = gradient_checker.compute_gradient_error([x_tf, alpha_tf], - [x_shape, []], + [x_shape, alpha_np.shape], y_tf, x_shape, [x_np, alpha_np], - delta=1e-2) - eps = 1e-4 + delta=0.002) + err_clip = gradient_checker.compute_gradient_error( + [x_tf, alpha_tf, clip_tf], + [x_shape, alpha_np.shape, clip_np.shape], + z_tf, x_shape, + [x_np, alpha_np, clip_np], + delta=0.002) + eps = 2e-4 self.assertLess(err, eps) + self.assertLess(err_clip, eps) if __name__ == '__main__': diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index b5a67206f3433ab3cf5ee5594557aadf8a09983b..8b2b31d5bc09778681881f2ca68b24c16bcff3d5 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -145,6 +145,9 @@ tf_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], + tags = [ + "no_oss", # Flaky due to port collisions + ], ) filegroup( diff --git a/tensorflow/contrib/predictor/predictor.py b/tensorflow/contrib/predictor/predictor.py index dbc0028259ebe50bdbe8dee9ef3ccff1aff5507c..28fa815684dd5e242f82d51968d856553315e8d5 100644 --- a/tensorflow/contrib/predictor/predictor.py +++ b/tensorflow/contrib/predictor/predictor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Abstract base class for all predictors.""" from __future__ import absolute_import @@ -66,8 +65,9 @@ class Predictor(object): expected_keys = set(self.feed_tensors.keys()) unexpected_keys = input_keys - expected_keys if unexpected_keys: - raise ValueError('Got unexpected keys in input_dict: {}'.format( - unexpected_keys)) + raise ValueError( + 'Got unexpected keys in input_dict: {}\nexpected: {}'.format( + unexpected_keys, expected_keys)) feed_dict = {} for key in self.feed_tensors.keys(): diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 7ff186bc2ad7204d934c322a04ad1c3f2aa383ab..2c0ffaf6c0525db2aab4932d81eb784d77c64d16 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -13,6 +13,34 @@ py_library( deps = [], ) +py_library( + name = "graph_matcher", + srcs = [ + "python/graph_matcher.py", + ], + srcs_version = "PY2AND3", + deps = [], +) + +py_test( + name = "graph_matcher_test", + size = "small", + srcs = ["python/graph_matcher_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":graph_matcher", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "input_to_ops", srcs = ["python/input_to_ops.py"], @@ -43,6 +71,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":common", + ":graph_matcher", ":input_to_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/python:array_ops", @@ -58,6 +87,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":fold_batch_norms", + ":graph_matcher", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", @@ -147,10 +177,11 @@ py_test( py_test( name = "quantize_parameterized_test", - size = "medium", + size = "large", srcs = ["python/quantize_parameterized_test.py"], srcs_version = "PY2AND3", deps = [ + ":fold_batch_norms", ":quantize", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", @@ -188,9 +219,13 @@ py_test( srcs_version = "PY2AND3", deps = [ ":quantize_graph", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", "//tensorflow/python:variables", ], diff --git a/tensorflow/contrib/quantize/__init__.py b/tensorflow/contrib/quantize/__init__.py index f137723cb6636cc60138064b762d3b38aaac3511..5d4e4575c935e0a888c6e5e4d0db640d93e1bd49 100644 --- a/tensorflow/contrib/quantize/__init__.py +++ b/tensorflow/contrib/quantize/__init__.py @@ -25,7 +25,7 @@ from tensorflow.contrib.quantize.python.quantize_graph import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - "create_eval_graph," + "create_eval_graph", "create_training_graph", ] diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py index 0889f12de6aac53f70ecfa7b70fc19ac7b95a5fe..7ff9ad9f8412d7076bf12d6cf10772244444013f 100644 --- a/tensorflow/contrib/quantize/python/copy_graph_test.py +++ b/tensorflow/contrib/quantize/python/copy_graph_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorflow.quantized.mangle.copy_graph.""" +"""Tests for copy_graph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index c9d16fb32927855aa14b8b8b33457063e26f6e4d..647d4044001f7be701037d07dc46db86c0aa3a0e 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -21,7 +21,9 @@ from __future__ import print_function import re from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -29,7 +31,270 @@ from tensorflow.python.ops import nn_ops def FoldBatchNorms(graph): - """Finds batch norm layers in the graph, folds them into preceding layers. + """Finds batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + _FoldFusedBatchNorms(graph) + _FoldUnfusedBatchNorms(graph) + + +def _FoldFusedBatchNorms(graph): + """Finds fused batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + for match in _FindFusedBatchNorms(graph): + scope, sep, _ = match.layer_op.name.rpartition('/') + # Make sure new ops are added to `graph` and put on the same device as + # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope + # named `scope`. Otherwise, TF creates a unique scope whose name starts with + # `scope`. + with graph.as_default(), graph.name_scope(scope + sep), ops.device( + match.bn_op.device): + # new weights = old weights * gamma / sqrt(variance + epsilon) + # new biases = -mean * gamma / sqrt(variance + epsilon) + beta + multiplier_tensor = match.gamma_tensor * math_ops.rsqrt( + match.variance_tensor + match.bn_op.get_attr('epsilon')) + bias_tensor = math_ops.subtract( + match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias') + + # The shape of depthwise weights is different, so we need to reshape the + # multiplier_tensor to ensure that the scaled_weight_tensor has the + # expected shape. + if match.layer_op.type == 'DepthwiseConv2dNative': + new_shape = [ + match.weight_tensor.get_shape().as_list()[2], + match.weight_tensor.get_shape().as_list()[3] + ] + multiplier_tensor = array_ops.reshape( + multiplier_tensor, new_shape, name='scale_reshape') + + # TODO(suharshs): This naming of the following ops needs to carefully + # follow the naming expected by quantize.py. Generalize the quantize code + # to not require these delicate naming conventions. + scaled_weight_tensor = math_ops.multiply( + match.weight_tensor, multiplier_tensor, name='mul_fold') + + new_layer_tensor = _CloneWithNewOperands( + match.layer_op, match.input_tensor, scaled_weight_tensor) + + bias_add_tensor = math_ops.add( + new_layer_tensor, bias_tensor, name='add_fold') + + nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, + match.output_tensor) + if nodes_modified_count != 1: + raise ValueError( + 'Unexpected inputs to op: %s' % match.output_tensor.name) + + +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): + """Clones layer_op with input_tensor and weight_tensor as new inputs.""" + new_layer_name = layer_op.name.split('/')[-1] + '_Fold' + if layer_op.type == 'Conv2D': + return nn_ops.conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), + data_format=layer_op.get_attr('data_format'), + name=new_layer_name) + elif layer_op.type == 'MatMul': + return math_ops.matmul( + input_tensor, + weight_tensor, + transpose_a=layer_op.get_attr('transpose_a'), + transpose_b=layer_op.get_attr('transpose_b'), + name=new_layer_name) + elif layer_op.type == 'DepthwiseConv2dNative': + return nn.depthwise_conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + name=new_layer_name) + else: + raise ValueError('Cannot handle operation of type: %s' % layer_op.type) + + +def _FindFusedBatchNorms(graph): + """Finds all ops and tensors related to found FusedBatchNorms. + + Args: + graph: Graph to inspect. + + Yields: + _FusedBatchNormMatches. + """ + input_pattern = graph_matcher.OpTypePattern('*') + weight_pattern = graph_matcher.OpTypePattern('*') + gamma_pattern = graph_matcher.OpTypePattern('*') + beta_pattern = graph_matcher.OpTypePattern('*') + mean_pattern = graph_matcher.OpTypePattern('*') + variance_pattern = graph_matcher.OpTypePattern('*') + + conv_pattern = graph_matcher.OpTypePattern( + 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern]) + # MatMul has a Reshape between it and FusedBatchNorm. + matmul_pattern = graph_matcher.OpTypePattern( + 'MatMul', inputs=[input_pattern, weight_pattern]) + matmul_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', inputs=[matmul_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + conv_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', + inputs=[matmul_batch_norm_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern) + matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern) + + def _GetCommonTensors(match_result): + """Gets tensors needed for FusedBatchNormMatch from match_result.""" + input_tensor = match_result.get_tensor(input_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + gamma_tensor = match_result.get_tensor(gamma_pattern) + beta_tensor = match_result.get_tensor(beta_pattern) + # FusedBatchNorm in training is different from that in inference. It takes + # empty 'mean' and empty 'variance', and produces the mean and the variance + # of the batch. Therefore, when is_training is true, mean_tensor and + # variance_tensor point to 1st and 2nd (0-based) output of bn_op, + # respectively; when is_training is false, they point to bn_op's inputs. + is_training = bn_op.get_attr('is_training') + if is_training: + mean_tensor = bn_op.outputs[1] + variance_tensor = bn_op.outputs[2] + else: + mean_tensor = match_result.get_tensor(mean_pattern) + variance_tensor = match_result.get_tensor(variance_pattern) + return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) + + for match_result in conv_matcher.match_graph(graph): + layer_op = match_result.get_op(conv_pattern) + bn_op = match_result.get_op(conv_batch_norm_pattern) + # In the case of convolution the output_tensor is the output of bn_op. + output_tensor = bn_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + for match_result in matmul_matcher.match_graph(graph): + layer_op = match_result.get_op(matmul_pattern) + bn_op = match_result.get_op(matmul_batch_norm_pattern) + # In the MatMul case, the output of batch norm is reshaped back into a + # 2D tensor, so the output_tensor is the output of the Reshape op. + output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) + output_tensor = output_reshape_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + +class _FusedBatchNormMatch(object): + """Contains all information related to a found FusedBatchNorm.""" + + def __init__(self, layer_op, bn_op, output_tensor, input_tensor, + weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor): + self._layer_op = layer_op + self._bn_op = bn_op + self._output_tensor = output_tensor + self._input_tensor = input_tensor + self._weight_tensor = weight_tensor + self._gamma_tensor = gamma_tensor + self._beta_tensor = beta_tensor + self._mean_tensor = mean_tensor + self._variance_tensor = variance_tensor + + @property + def layer_op(self): + return self._layer_op + + @property + def bn_op(self): + return self._bn_op + + @property + def output_tensor(self): + return self._output_tensor + + @property + def input_tensor(self): + return self._input_tensor + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def gamma_tensor(self): + return self._gamma_tensor + + @property + def beta_tensor(self): + return self._beta_tensor + + @property + def mean_tensor(self): + return self._mean_tensor + + @property + def variance_tensor(self): + return self._variance_tensor + + +def _FoldUnfusedBatchNorms(graph): + """Finds unfused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise convolution. diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index 4f11188a551fa7054bf7c91f70ec9f3f591a4c8e..2cecf6851467f82675bd67bf1fb108e9a39df1b0 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.python.framework import dtypes @@ -35,29 +34,32 @@ conv2d = layers.conv2d fully_connected = layers.fully_connected separable_conv2d = layers.separable_conv2d -_DEFAULT_BATCH_NORM_PARAMS = { - 'center': True, - 'scale': True, - 'decay': 1.0 - 0.003, - 'fused': False, -} - # TODO(suharshs): Use parameterized test once OSS TF supports it. class FoldBatchNormsTest(test_util.TensorFlowTestCase): def _RunTestOverParameters(self, test_fn): parameters_list = [ - # (relu, relu_op_name, with_bypass) - (nn_ops.relu6, 'Relu6', False), - (nn_ops.relu, 'Relu', False), - (nn_ops.relu6, 'Relu6', True), - (nn_ops.relu, 'Relu', True), + # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, False, False), + (nn_ops.relu, 'Relu', False, False, False), + (nn_ops.relu6, 'Relu6', True, False, False), + (nn_ops.relu, 'Relu', True, False, False), + (nn_ops.relu6, 'Relu6', False, True, False), + (nn_ops.relu, 'Relu', False, True, False), + (nn_ops.relu6, 'Relu6', True, True, False), + (nn_ops.relu, 'Relu', True, True, False), + # Fused batch norm always has scaling enabled. + (nn_ops.relu6, 'Relu6', False, True, True), + (nn_ops.relu, 'Relu', False, True, True), + (nn_ops.relu6, 'Relu6', True, True, True), + (nn_ops.relu, 'Relu', True, True, True), ] - for parameters in parameters_list: - test_fn(parameters[0], parameters[1], parameters[2]) + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) - def _TestFoldConv2d(self, relu, relu_op_name, with_bypass): + def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling, + fused_batch_norm): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Args: @@ -65,6 +67,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -74,12 +78,17 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -88,12 +97,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/mul']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) - folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') + folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold') self.assertEqual(folded_conv.type, 'Conv2D') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) @@ -101,16 +111,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/convolution_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/Conv2D_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldConv2d(self): self._RunTestOverParameters(self._TestFoldConv2d) - def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass): + def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Tests that folding works even with an input shape where some dimensions are @@ -121,6 +133,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -137,7 +151,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): weights_initializer=self._WeightInit(0.09), activation_fn=activation_fn, normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') @@ -148,11 +163,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', scope + '/BatchNorm/batchnorm/mul' + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) - folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') + folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold') self.assertEqual(folded_conv.type, 'Conv2D') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) @@ -160,7 +176,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/convolution_Fold', scope + '/BatchNorm/batchnorm/sub' + scope + '/Conv2D_Fold', + self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) @@ -168,62 +185,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def testFoldConv2dUnknownShape(self): self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) - def _TestFoldConv2dWithoutScale(self, relu, relu_op_name, with_bypass): - """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. - """ - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - out_depth = 3 if with_bypass else 32 - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/Rsqrt']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) - - folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') - self.assertEqual(folded_conv.type, 'Conv2D') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/convolution_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldConv2dWithoutScale(self): - self._RunTestOverParameters(self._TestFoldConv2dWithoutScale) - - def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass): + def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding cases: inputs -> FC with batch norm -> Relu*. Args: @@ -231,6 +194,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -239,12 +204,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_depth = 256 if with_bypass else 128 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -253,9 +221,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/mul']) + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') @@ -266,71 +235,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/MatMul_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/MatMul_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldFullyConnectedLayer(self): self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) - def _TestFoldFullyConnectedLayerWithoutScale(self, relu, relu_op_name, - with_bypass): - """Tests folding cases: inputs -> FC with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. - """ - g = ops.Graph() - with g.as_default(): - batch_size, depth = 5, 256 - inputs = array_ops.zeros((batch_size, depth)) - out_depth = 256 if with_bypass else 128 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/Rsqrt']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) - - folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') - self.assertEqual(folded_conv.type, 'MatMul') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/MatMul_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldFullyConnectedLayerWithoutScale(self): - self._RunTestOverParameters(self._TestFoldFullyConnectedLayerWithoutScale) - - def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass): + def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. Args: @@ -338,6 +254,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -346,13 +264,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -368,9 +291,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') self.assertEqual(scale_reshape.type, 'Reshape') - self._AssertInputOpsAre(scale_reshape, - [scope + '/BatchNorm/batchnorm/mul', - scope + '/scale_reshape/shape']) + self._AssertInputOpsAre(scale_reshape, [ + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm), + scope + '/scale_reshape/shape' + ]) self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') @@ -381,77 +305,35 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/depthwise_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/depthwise_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldDepthwiseConv2d(self): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) - def _TestFoldDepthwiseConv2dWithoutScale(self, relu, relu_op_name, - with_bypass): - """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. - """ - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/depthwise_weights/read', - scope + '/scale_reshape']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) - - scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') - self.assertEqual(scale_reshape.type, 'Reshape') - self._AssertInputOpsAre(scale_reshape, - [scope + '/BatchNorm/batchnorm/Rsqrt', - scope + '/scale_reshape/shape']) - self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) - - folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') - self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/depthwise_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldDepthwiseConv2dWithoutScale(self): - self._RunTestOverParameters(self._TestFoldDepthwiseConv2dWithoutScale) + def _BatchNormParams(self, scale=True, fused=False): + return { + 'center': True, + 'scale': scale, + 'decay': 1.0 - 0.003, + 'fused': fused + } + + def _BatchNormMultiplierName(self, scope, has_scaling, fused): + if has_scaling: + if fused: + return scope + '/mul' + return scope + '/BatchNorm/batchnorm/mul' + return scope + '/BatchNorm/batchnorm/Rsqrt' + + def _BathNormBiasName(self, scope, fused): + if fused: + return scope + '/bias' + return scope + '/BatchNorm/batchnorm/sub' def _WeightInit(self, stddev): """Returns a truncated normal variable initializer. diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e3581cc55905a0af7d0464bc0ec673d3ed7f0363 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -0,0 +1,200 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities that match patterns in a tf.Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class OpTypePattern(object): + """A tree pattern that matches TF expressions with certain op types.""" + + def __init__(self, op_type, name=None, inputs=None): + """Initializes an OpTypePattern. + + Args: + op_type: string that specifies the allowed types of the root. It can be + (1) an op type, e.g. 'Conv2D', + (2) '*', i.e. wildcard, or + (3) multiple op types separated by '|', e.g., 'Relu|Relu6'. + We could use regex strings, which might be worthwhile when we have many + similar TF op types. + name: Optional string. The name of the pattern that can be looked up in + MatchResult. + inputs: Optional list of `OpTypePattern`s or strings that specify the + patterns for the inputs of a matching op. If None, this pattern accepts + any inputs of a matching op. + """ + self._op_type = op_type + self._name = name + if inputs is None: + inputs = [] + self._inputs = [ + input_pattern if isinstance(input_pattern, OpTypePattern) else + OpTypePattern(input_pattern) for input_pattern in inputs + ] + + @property + def op_type(self): + return self._op_type + + @property + def inputs(self): + return self._inputs + + @property + def name(self): + return self._name + + +class MatchResult(object): + r"""Encapsulates the result of a match done by GraphMatcher. + + MatchResult contains a map from OpTypePattern to the matching op and tensor. + When the matching op has multiple output tensors, the matching tensor is the + output tensor used by the matching op of the parent pattern. E.g., when we + match graph + + - + + / \y0 y1/ \ + x split z + | + y (nodes are ops; edges are going up) + + against add_pattern defined as + + y1_pattern = OpTypePattern('*') + z_pattern = OpTypePattern('*') + add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern]) + + the matching op of `y1_pattern` is `split`, and the matching tensor of + `y1_pattern` + is `y1` not `y0`. + """ + + def __init__(self): + self._pattern_to_op_tensor = {} + self._name_to_pattern = {} + + def add(self, pattern, op, tensor): + self._pattern_to_op_tensor[pattern] = op, tensor + if pattern.name is not None: + if pattern.name in self._name_to_pattern: + raise ValueError( + 'Name %s is already bound to another pattern' % pattern.name) + self._name_to_pattern[pattern.name] = pattern + + def _to_pattern(self, pattern_or_name): + if isinstance(pattern_or_name, OpTypePattern): + return pattern_or_name + + if isinstance(pattern_or_name, str): + return self._name_to_pattern[pattern_or_name] + + raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' + % type(pattern_or_name)) + + def get_op(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0] + + def get_tensor(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1] + + +class GraphMatcher(object): + """Checks if a particular subgraph matches a given pattern.""" + + def __init__(self, pattern): + """Initializes a GraphMatcher. + + Args: + pattern: The `OpTypePattern` against which `GraphMatcher` matches + subgraphs. + """ + self._pattern = pattern + + def _match_pattern(self, pattern, op, tensor): + """Returns whether an TF expression rooted at `op` matches `pattern`. + + If there is a match, adds to `self._match_result` the matching op and tensor + with key `pattern`. + + Args: + pattern: An `OpTypePattern`. + op: A `tf.Operation` to match against the pattern. + tensor: the output `tf.Tensor` of `op` that is used by the matching op of + `pattern`'s parent. Can be None if `pattern` is already the root of the + pattern tree. + + Returns: + True if an TF expression rooted at `op` matches `pattern`. + """ + if pattern.op_type != '*': + if op.type not in pattern.op_type.split('|'): + return False + + self._match_result.add(pattern, op, tensor) + + if not pattern.inputs: + # If pattern.inputs is empty, skips the rest and accepts all the inputs. + return True + + return len(op.inputs) == len(pattern.inputs) and all([ + self._match_pattern(input_pattern, input_tensor.op, input_tensor) + for input_tensor, input_pattern in zip(op.inputs, pattern.inputs) + ]) + + def match_op(self, op): + """Matches `op` against `self._pattern`. + + Args: + op: `tf.Operation` to match against the pattern. + + Returns: + Returns a `MatchResult` if `op` matches the pattern; otherwise, returns + None. + """ + self._match_result = MatchResult() + if not self._match_pattern(self._pattern, op, tensor=None): + return None + return self._match_result + + def match_ops(self, ops): + """Matches each operation in `ops` against `self._pattern`. + + Args: + ops: collection of `tf.Operation` to match against the pattern. + + Yields: + `MatchResult` for each `tf.Operation` that matches the pattern. + """ + for op in ops: + match_result = self.match_op(op) + if match_result: + yield match_result + + def match_graph(self, graph): + """Matches each operation in `graph` against `self._pattern`. + + Args: + graph: `tf.Graph` containing operations to match. + + Yields: + `MatchResult` for each `tf.Operation` in `graph` that matches the pattern. + """ + # Python 3.3.2+ implements `yield from`, but for now: + for match_result in self.match_ops(graph.get_operations()): + yield match_result diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e1572865e423e569ee3b280036c0e02b71b70648 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for graph_matcher.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python import ops as contrib_ops +from tensorflow.contrib.layers.python.layers import initializers +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import graph_matcher +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class GraphMatcherTest(test_util.TensorFlowTestCase): + + def test_conv_layer(self): + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3]) + + with contrib_ops.arg_scope( + [layers.batch_norm], fused=True, is_training=True, trainable=True): + return layers.convolution( + inputs, + num_outputs=16, + kernel_size=3, + stride=1, + padding='VALID', + activation_fn=nn_ops.relu, + normalizer_fn=layers.batch_norm, + normalizer_params={}, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + trainable=True, + scope=None) + + inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs') + relu_pattern = graph_matcher.OpTypePattern( + 'Relu', + name='relu', + inputs=[ + graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + graph_matcher.OpTypePattern( + 'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*', + '*' + ]) + ]) + matcher = graph_matcher.GraphMatcher(relu_pattern) + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + self.assertEqual(match_result.get_tensor(inputs_pattern), inputs) + self.assertEqual(match_result.get_tensor('inputs'), inputs) + + def test_multiple_outputs(self): + # - + + # / \y0 y1/ \ + # x split z + # | + # y (nodes are ops; edges are going up) + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[1], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[2], name='y') + y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0) + z = array_ops.placeholder(dtypes.float32, shape=[1], name='z') + math_ops.add(x, y0) + math_ops.subtract(y1, z) + + y1_pattern = graph_matcher.OpTypePattern('*') + minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*']) + matcher = graph_matcher.GraphMatcher(minus_pattern) + + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + + self.assertEqual(y0.op, y1.op) + self.assertEqual(match_result.get_op(y1_pattern), y1.op) + self.assertEqual(match_result.get_tensor(y1_pattern), y1) + + def test_oneof_pattern(self): + # - + + # / \ / \ + # x y z + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[], name='y') + z = array_ops.placeholder(dtypes.float32, shape=[], name='z') + plus = x + y + minus = y - z + + add_or_sub_pattern = graph_matcher.OpTypePattern( + 'Add|Sub', inputs=['*', '*']) + matcher = graph_matcher.GraphMatcher(add_or_sub_pattern) + self.assertEqual([ + match_result.get_op(add_or_sub_pattern) + for match_result in matcher.match_graph(g) + ], [plus.op, minus.op]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index aaf3e92b8ea518fbbe55628b856e0191c949c619..d647bb94e849c713c2aca93c53f372bae5857c43 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -25,7 +25,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import variables -def _create_graph(input_graph, is_training, elements=None): +def _create_graph(input_graph, + is_training, + elements=None, + device_name_or_function=None): """Returns a transformed training input_graph for simulated quantization. The forward pass has fake quantization ops inserted to simulate the error @@ -36,12 +39,12 @@ def _create_graph(input_graph, is_training, elements=None): is_training: Whether quantizing training or eval graph. elements: (Optional) List of Tensors and Operations in input_graph whose corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. Returns: - Returns a tuple(g, l) where: g is new tf.Graph that is rewritten for simulated quantization. l is a list of Tensors/Operations in g corresponding to the provided input - elements. + elements, if elements is not None. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or @@ -49,11 +52,14 @@ def _create_graph(input_graph, is_training, elements=None): """ # TODO(suharshs): Describe the process in more detail in the doc string. g = copy_graph.CopyGraph(input_graph) - fold_batch_norms.FoldBatchNorms(g) - quantize.Quantize(g, is_training=is_training) - return_elements = [] + with g.as_default(): + with ops.device(device_name_or_function): + fold_batch_norms.FoldBatchNorms(g) + quantize.Quantize(g, is_training=is_training) if elements is None: - elements = [] + return g + + return_elements = [] for element in elements: if isinstance(element, (ops.Tensor, variables.Variable)): return_elements.append(g.get_tensor_by_name(element.name)) @@ -66,7 +72,9 @@ def _create_graph(input_graph, is_training, elements=None): return g, return_elements -def create_training_graph(input_graph, elements=None): +def create_training_graph(input_graph, + elements=None, + device_name_or_function=None): """Returns a transformed training input_graph for simulated quantization. The forward pass has fake quantization ops inserted to simulate the error @@ -76,21 +84,25 @@ def create_training_graph(input_graph, elements=None): input_graph: The tf.Graph to be transformed. elements: (Optional) List of Tensors and Operations in input_graph whose corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. Returns: - Returns a tuple(g, l) where: g is new tf.Graph that is rewritten for simulated quantization. l is a list of Tensors/Operations in g corresponding to the provided input - elements. + elements, if elements is not None. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. """ - return _create_graph(input_graph, True, elements) + return _create_graph( + input_graph=input_graph, + is_training=True, + elements=elements, + device_name_or_function=device_name_or_function) -def create_eval_graph(input_graph, elements=None): +def create_eval_graph(input_graph, elements=None, device_name_or_function=None): """Returns a transformed eval input_graph for simulated quantization. The forward pass has fake quantization ops inserted to simulate the error @@ -100,15 +112,19 @@ def create_eval_graph(input_graph, elements=None): input_graph: The tf.Graph to be transformed. elements: (Optional) List of Tensors and Operations in input_graph whose corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. Returns: - Returns a tuple(g, l) where: g is new tf.Graph that is rewritten for simulated quantization. l is a list of Tensors/Operations in g corresponding to the provided input - elements. + elements, if elements is not None. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. """ - return _create_graph(input_graph, False, elements) + return _create_graph( + input_graph=input_graph, + is_training=False, + elements=elements, + device_name_or_function=device_name_or_function) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index 382076672a70c873ae7c1384e0706231a0ba8a55..3407ace3914fe2de2506a2952ea5d1bf19028bb9 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -18,29 +18,41 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import quantize_graph from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class QuantizeTest(test_util.TensorFlowTestCase): +class QuantizeGraphTest(test_util.TensorFlowTestCase): # We have a lot of other tests that test the details of the rewrite, here we # just the specific features of the quantize_graph API. def testReturnedElementsTraining(self): + self._TestReturnElements(True) + + def testReturnedElementsEval(self): + self._TestReturnElements(False) + + def _TestReturnElements(self, is_training): graph = ops.Graph() with graph.as_default(): a = constant_op.constant(1.0) b = variables.Variable(2.0) c = a + b elements = [a, b, c.op] - for element in elements: - print(element) - q_graph, returned_elements = quantize_graph.create_training_graph( - graph, elements=elements) + if is_training: + q_graph, returned_elements = quantize_graph.create_training_graph( + graph, elements=elements) + else: + q_graph, returned_elements = quantize_graph.create_eval_graph( + graph, elements=elements) # Make sure q_graph is different from graph. self.assertTrue(graph != q_graph) # Check that the returned elements are part of the new graph. @@ -50,25 +62,79 @@ class QuantizeTest(test_util.TensorFlowTestCase): for element, returned_element in zip(elements, returned_elements): self.assertEqual(element.name, returned_element.name) - # We have a lot of other tests that test the details of the rewrite, here we - # just the specific features of the quantize_graph API. - def testReturnedElementsEval(self): + def testNoReturnElementsTraining(self): + self._TestNoReturnElements(True) + + def testNoReturnElementsEval(self): + self._TestNoReturnElements(False) + + def _TestNoReturnElements(self, is_training): graph = ops.Graph() with graph.as_default(): a = constant_op.constant(1.0) b = variables.Variable(2.0) - c = a + b - elements = [a, b, c.op] - q_graph, returned_elements = quantize_graph.create_eval_graph( - graph, elements=elements) + _ = a + b + if is_training: + q_graph = quantize_graph.create_training_graph(graph) + else: + q_graph = quantize_graph.create_eval_graph(graph) + # Check that quantize_graph didn't return a tuple when elements isn't + # provided. + self.assertTrue(isinstance(q_graph, ops.Graph)) # Make sure q_graph is different from graph. self.assertTrue(graph != q_graph) - # Check that the returned elements are part of the new graph. - for returned_element in returned_elements: - self.assertEqual(q_graph, returned_element.graph) - # Check that the elements match with the one from the input graph. - for element, returned_element in zip(elements, returned_elements): - self.assertEqual(element.name, returned_element.name) + + def testDeviceNameTraining(self): + self._TestDeviceName(True) + + def testDeviceNameEval(self): + self._TestDeviceName(False) + + def _TestDeviceName(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + conv = layers.conv2d( + inputs, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + _ = nn_ops.relu6(conv) + + device_name = '/job:oink/task:0/device:CPU:0' + if is_training: + q_graph = quantize_graph.create_training_graph( + graph, device_name_or_function=device_name) + else: + q_graph = quantize_graph.create_eval_graph( + graph, device_name_or_function=device_name) + + orig_variable_names = set( + [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + q_variables = q_graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # Ensure that variables were added. + self.assertTrue(len(orig_variable_names) < len(q_variables)) + # All added variables should have the specified device name. + for var in q_variables: + if var.name not in orig_variable_names: + self.assertEqual(var.device, device_name) + + def _WeightInit(self, stddev): + """Returns truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initialized that initialzes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) if __name__ == '__main__': diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index b5a32a7266a4c3ddf9a481fd9b292ab0f1812a9a..3e62f95bd63db3134ba0b96c46b4a92aa73ebef9 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.contrib.quantize.python import quantize from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -35,18 +36,11 @@ conv2d = layers.conv2d fully_connected = layers.fully_connected separable_conv2d = layers.separable_conv2d -_DEFAULT_BATCH_NORM_PARAMS = { - 'center': True, - 'scale': True, - 'decay': 1.0 - 0.003, - 'fused': False, -} - -# TODO(suharshs): Use parameterized test once OSS TF supports it. class QuantizeTest(test_util.TensorFlowTestCase): - def _RunTestOverParameters(self, test_fn): + def _RunWithoutBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. parameters_list = [ # (activation, activation_op_name, with_bypass, delay) (nn_ops.relu6, 'Relu6', False, None), @@ -60,10 +54,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): (array_ops.identity, 'Identity', True, None), (nn_ops.relu6, 'Relu6', True, 5000), (nn_ops.relu, 'Relu', True, 5000), - (array_ops.identity, 'Identity', True, 5000) + (array_ops.identity, 'Identity', True, 5000), ] - for parameters in parameters_list: - test_fn(parameters[0], parameters[1], parameters[2], parameters[3]) + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3]) def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, with_bypass, delay): @@ -107,7 +101,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/convolution' + output_op_name = scope + '/Conv2D' self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -137,7 +131,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_Conv2dWithoutBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_Conv2dWithoutBatchNorm) + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_Conv2dWithoutBatchNorm) def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, with_bypass, delay): @@ -210,7 +205,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_FCWithoutBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_FCWithoutBatchNorm) + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_FCWithoutBatchNorm) def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( self, activation, activation_op_name, with_bypass, delay): @@ -284,11 +280,43 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): - self._RunTestOverParameters( + self._RunWithoutBatchNormTestOverParameters( self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + def _RunBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. + parameters_list = [ + # (activation, activation_op_name, with_bypass, delay, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, None, False), + (nn_ops.relu, 'Relu', False, None, False), + (array_ops.identity, 'Identity', False, None, False), + (nn_ops.relu6, 'Relu6', False, 5000, False), + (nn_ops.relu, 'Relu', False, 5000, False), + (array_ops.identity, 'Identity', False, 5000, False), + (nn_ops.relu6, 'Relu6', True, None, False), + (nn_ops.relu, 'Relu', True, None, False), + (array_ops.identity, 'Identity', True, None, False), + (nn_ops.relu6, 'Relu6', True, 5000, False), + (nn_ops.relu, 'Relu', True, 5000, False), + (array_ops.identity, 'Identity', True, 5000, False), + (nn_ops.relu6, 'Relu6', False, None, True), + (nn_ops.relu, 'Relu', False, None, True), + (array_ops.identity, 'Identity', False, None, True), + (nn_ops.relu6, 'Relu6', False, 5000, True), + (nn_ops.relu, 'Relu', False, 5000, True), + (array_ops.identity, 'Identity', False, 5000, True), + (nn_ops.relu6, 'Relu6', True, None, True), + (nn_ops.relu, 'Relu', True, None, True), + (array_ops.identity, 'Identity', True, None, True), + (nn_ops.relu6, 'Relu6', True, 5000, True), + (nn_ops.relu, 'Relu', True, 5000, True), + (array_ops.identity, 'Identity', True, 5000, True) + ] + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) + def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay): + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -298,25 +326,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. """ self._testQuantize_Conv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_Conv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_Conv2dWithBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) + self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, use_ema): + with_bypass, delay, fused_batch_norm, + use_ema): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -326,6 +358,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -337,39 +370,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 out_depth = 3 if with_bypass else 32 scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm. - weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') - stride = [stride, stride] - conv_fold = nn_ops.convolution( - input=inputs, - filter=mul_fold, + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, padding='SAME', - strides=stride, - data_format='NHWC', - name=scope + '/convolution_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionaly) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) @@ -384,7 +407,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if (delay and use_ema) else '/convolution_Fold') + if (delay and use_ema) else '/Conv2D_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -413,7 +436,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay): + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -423,25 +446,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. """ self._testQuantize_FCWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_FCWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_FCWithBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_FCWithBatchNorm) + self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, use_ema): + with_bypass, delay, fused_batch_norm, + use_ema): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -451,6 +478,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -461,32 +489,27 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs = array_ops.zeros((batch_size, depth)) out_depth = 256 if with_bypass else 128 scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm. - weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') - fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold') + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionaly) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) @@ -530,7 +553,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def _TestQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay): + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -540,26 +564,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. """ self._testQuantize_DepthwiseConv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_DepthwiseConv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_DepthwiseConv2dWithBatchNorm(self): - self._RunTestOverParameters( - self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + self._RunBatchNormTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithBatchNorm) def _testQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay, use_ema): + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -569,6 +597,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -579,46 +608,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm. - weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read') - .outputs[0]) - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - new_shape = [ - weights.get_shape().as_list()[2], weights.get_shape().as_list()[3] - ] - bn_mult_reshaped = array_ops.reshape( - bn_mult, new_shape, name=scope + '/gamma_reshape') - mul_fold = math_ops.multiply( - weights, bn_mult_reshaped, name=scope + '/mul_fold') - stride = [1, stride, stride, 1] - conv_fold = nn_ops.depthwise_conv2d( - input=inputs, - filter=mul_fold, + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, padding='SAME', - strides=stride, - name=scope + '/depthwise_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionaly) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) quantization_node_name = 'FakeQuantWithMinMaxVars' @@ -660,6 +673,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + def _BatchNormParams(self, fused=False): + return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index a6bd809bb7de0b674671d09e4a941675976ce8ab..4a82eac1978cf834732e339e4e76a4507b9a090c 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -65,28 +65,5 @@ class QuantizeTest(test_util.TensorFlowTestCase): """ return init_ops.truncated_normal_initializer(stddev=stddev) - def _AssertInputOpsAre(self, op, in_op_names): - """Asserts that all inputs to op come from in_op_names (disregarding order). - - Args: - op: Operation to check inputs for. - in_op_names: List of strings, operations where all op's inputs should - come from. - """ - expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names] - self.assertItemsEqual([t.name for t in op.inputs], expected_inputs) - - def _AssertOutputGoesToOps(self, op, graph, out_op_names): - """Asserts that outputs from op go to out_op_names (and perhaps others). - - Args: - op: Operation to check outputs for. - graph: Graph where output operations are located. - out_op_names: List of strings, operations where op's outputs should go. - """ - for out_op_name in out_op_names: - out_op = graph.get_operation_by_name(out_op_name) - self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) - if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index db190a1a41668bff3f6db1c674192980db068838..8b34465d21d14508c24056b588f2533d8fea6a1d 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -27,13 +27,15 @@ import math from tensorflow.contrib.receptive_field.python.util import graph_compute_order from tensorflow.contrib.util import make_ndarray from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.framework import ops as framework_ops +import numpy as np # White-listed layer operations, which do not affect the receptive field # computation. _UNCHANGED_RF_LAYER_OPS = [ - "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity", - "VariableV2", "Sub", "Rsqrt", "ConcatV2" -] + 'Add', 'BiasAdd', 'Ceil', 'ConcatV2', 'Const', 'Floor', 'Identity', 'Log', + 'Mul', 'Pow', 'RealDiv', 'Relu', 'Round', 'Rsqrt', 'Softplus', 'Sub', + 'VariableV2'] # Different ways in which padding modes may be spelled. _VALID_PADDING = ["VALID", b"VALID"] @@ -238,7 +240,8 @@ def _get_layer_params(node, name_to_order_node): padding_x = 0 padding_y = 0 else: - raise ValueError("Unknown layer op: %s" % node.op) + raise ValueError("Unknown layer for operation '%s': %s" % + (node.name, node.op)) return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y @@ -304,13 +307,103 @@ def _get_effective_padding_node_input(stride, padding, return stride * effective_padding_output + padding -def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): - """Computes receptive field (RF) parameters from a GraphDef object. +class ReceptiveField: + """ + Receptive field of a convolutional neural network. + + Args: + size: Receptive field size. + stride: Effective stride. + padding: Effective padding. + """ + def __init__(self, size, stride, padding): + self.size = np.asarray(size) + self.stride = np.asarray(stride) + self.padding = np.asarray(padding) + + def compute_input_center_coordinates(self, y, axis=None): + """ + Computes the center of the receptive field that generated a feature. + + Args: + y: An array of feature coordinates with shape `(..., d)`, where `d` is the + number of dimensions of the coordinates. + axis: The dimensions for which to compute the input center coordinates. + If `None` (the default), compute the input center coordinates for all + dimensions. + + Returns: + x: Center of the receptive field that generated the features, at the input + of the network. + + Raises: + ValueError: If the number of dimensions of the feature coordinates does + not match the number of elements in `axis`. + """ + # Use all dimensions. + if axis is None: + axis = range(self.size.size) + # Ensure axis is a list because tuples have different indexing behavior. + axis = list(axis) + y = np.asarray(y) + if y.shape[-1] != len(axis): + raise ValueError("Dimensionality of the feature coordinates `y` (%d) " + "does not match dimensionality of `axis` (%d)" % + (y.shape[-1], len(axis))) + return - self.padding[axis] + y * self.stride[axis] + \ + (self.size[axis] - 1) / 2 + + def compute_feature_coordinates(self, x, axis=None): + """ + Computes the position of a feature given the center of a receptive field. + + Args: + x: An array of input center coordinates with shape `(..., d)`, where `d` + is the number of dimensions of the coordinates. + axis: The dimensions for which to compute the feature coordinates. + If `None` (the default), compute the feature coordinates for all + dimensions. + + Returns: + y: Coordinates of the features. + + Raises: + ValueError: If the number of dimensions of the input center coordinates + does not match the number of elements in `axis`. + """ + # Use all dimensions. + if axis is None: + axis = range(self.size.size) + # Ensure axis is a list because tuples have different indexing behavior. + axis = list(axis) + x = np.asarray(x) + if x.shape[-1] != len(axis): + raise ValueError("Dimensionality of the input center coordinates `x` " + "(%d) does not match dimensionality of `axis` (%d)" % + (x.shape[-1], len(axis))) + return (x + self.padding[axis] + (1 - self.size[axis]) / 2) / \ + self.stride[axis] + + def __iter__(self): + return iter(np.concatenate([self.size, self.stride, self.padding])) + + +def compute_receptive_field_from_graph_def(graph_def, input_node, output_node, + stop_propagation=None): + """Computes receptive field (RF) parameters from a Graph or GraphDef object. + + The algorithm stops the calculation of the receptive field whenever it + encounters an operation in the list `stop_propagation`. Stopping the + calculation early can be useful to calculate the receptive field of a + subgraph such as a single branch of the + [inception network](https://arxiv.org/abs/1512.00567). Args: - graph_def: GraphDef object. - input_node: Name of the input node from graph. - output_node: Name of the output node from graph. + graph_def: Graph or GraphDef object. + input_node: Name of the input node or Tensor object from graph. + output_node: Name of the output node or Tensor object from graph. + stop_propagation: List of operation or scope names for which to stop the + propagation of the receptive field. Returns: rf_size_x: Receptive field size of network in the horizontal direction, with @@ -331,6 +424,18 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): cannot be found. For network criterion alignment, see photos/vision/features/delf/g3doc/rf_computation.md """ + # Convert a graph to graph_def if necessary. + if isinstance(graph_def, framework_ops.Graph): + graph_def = graph_def.as_graph_def() + + # Convert tensors to names. + if isinstance(input_node, framework_ops.Tensor): + input_node = input_node.op.name + if isinstance(output_node, framework_ops.Tensor): + output_node = output_node.op.name + + stop_propagation = stop_propagation or [] + # Computes order of computation for a given graph. name_to_order_node = graph_compute_order.get_compute_order( graph_def=graph_def) @@ -422,6 +527,10 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): # Loop over this node's inputs and potentially propagate information down. for inp_name in node.input: + # Stop the propagation of the receptive field. + if any(inp_name.startswith(stop) for stop in stop_propagation): + logging.vlog(3, "Skipping explicitly ignored node %s.", node.name) + continue logging.vlog(4, "inp_name = %s", inp_name) inp_node = name_to_order_node[inp_name].node logging.vlog(4, "inp_node = \n%s", inp_node) @@ -480,6 +589,7 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): raise ValueError("Output node was not found") if input_node not in rf_sizes_x: raise ValueError("Input node was not found") - return (rf_sizes_x[input_node], rf_sizes_y[input_node], - effective_strides_x[input_node], effective_strides_y[input_node], - effective_paddings_x[input_node], effective_paddings_y[input_node]) + return ReceptiveField( + (rf_sizes_x[input_node], rf_sizes_y[input_node]), + (effective_strides_x[input_node], effective_strides_y[input_node]), + (effective_paddings_x[input_node], effective_paddings_y[input_node])) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py index 2771389250b1518f33ebadf3f1cfd23e653dab93..8d7d5440f630a3a78749e04a5eb058d637c258fc 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test +import numpy as np def create_test_network_1(): @@ -150,6 +151,31 @@ def create_test_network_5(): return g +def create_test_network_6(): + """Aligned network with dropout for test. + + The graph is similar to create_test_network_1(), except that the right branch + has dropout normalization. + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]]) + l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID') + l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID') + dropout = slim.dropout(l3) + # Addition. + nn.relu(l1 + dropout, name='output') + return g + + class RfUtilsTest(test.TestCase): def testComputeRFFromGraphDefAligned(self): @@ -220,6 +246,36 @@ class RfUtilsTest(test.TestCase): self.assertEqual(effective_padding_x, 0) self.assertEqual(effective_padding_y, 0) + def testComputeRFFromGraphDefStopPropagation(self): + graph_def = create_test_network_6().as_graph_def() + input_node = 'input_image' + output_node = 'output' + # Compute the receptive field but stop the propagation for the random + # uniform variable of the dropout. + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node, + ['Dropout/dropout/random_uniform'])) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 1) + self.assertEqual(effective_padding_y, 1) + + def testComputeCoordinatesRoundtrip(self): + graph_def = create_test_network_1() + input_node = 'input_image' + output_node = 'output' + rf = receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node) + + x = np.random.randint(0, 100, (50, 2)) + y = rf.compute_feature_coordinates(x) + x2 = rf.compute_input_center_coordinates(y) + + self.assertAllEqual(x, x2) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 3e6c09662fe8b54ff4c07175cbba99b87e27969c..29ba26d75dcce6ac1983f82dc2dfc03323e0ec5f 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -24,6 +24,22 @@ load( "tf_kernel_tests_linkstatic", ) +cc_library( + name = "all_ops", + deps = [ + ":gru_ops_op_lib", + ":lstm_ops_op_lib", + ], +) + +cc_library( + name = "all_kernels", + deps = [ + ":gru_ops_kernels", + ":lstm_ops_kernels", + ], +) + tf_custom_op_py_library( name = "rnn_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]) + [ @@ -34,14 +50,13 @@ tf_custom_op_py_library( ":python/ops/_lstm_ops.so", ], kernels = [ - ":gru_ops_kernels", - ":lstm_ops_kernels", - ":gru_ops_op_lib", - ":lstm_ops_op_lib", + ":all_ops", + ":all_kernels", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":benchmarking", ":gru_ops", ":lstm_ops", "//tensorflow/contrib/compiler:compiler_py", @@ -125,6 +140,9 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = [ + "optonly", + ], ) cuda_py_tests( @@ -138,6 +156,7 @@ cuda_py_tests( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -147,6 +166,7 @@ cuda_py_tests( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", ], shard_count = 10, ) @@ -259,6 +279,7 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = ["no_oss"], ) tf_cc_test( @@ -386,3 +407,13 @@ py_test( "//tensorflow/python:variables", ], ) + +py_library( + name = "benchmarking", + srcs = ["python/kernel_tests/benchmarking.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index ffeb9953c576b76a613e1ba46ff087c184acd532..941a457fd3ada312b981fb23c769ff9ecea9ff13 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -41,6 +41,142 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { +template +void LSTMBlockCellFpropWithEigen( + const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d, + const T forget_bias, const T cell_clip, bool use_peephole, + typename TTypes::ConstMatrix x, typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::Matrix xh, typename TTypes::Matrix i, + typename TTypes::Matrix cs, typename TTypes::Matrix f, + typename TTypes::Matrix o, typename TTypes::Matrix ci, + typename TTypes::Matrix co, typename TTypes::Matrix icfo, + typename TTypes::Matrix h) { + // Concat xh = [x, h]. + xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x; + xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev; + + // states1 = xh * w + b + typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + Eigen::array b_shape({1, b.dimensions()[0]}); + Eigen::array broadcast_shape({cell.batch_size(), 1}); + icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); + + Eigen::array p_shape({1, cell.cell_size()}); + Eigen::array p_broadcast_shape({cell.batch_size(), 1}); + + // Input gate. + if (use_peephole) { + auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape); + i.device(d) = + (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep) + .sigmoid(); + } else { + i.device(d) = + icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid(); + } + + // Cell input. + ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh(); + + // Forget gate (w/ bias). + if (use_peephole) { + auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + + f.constant(forget_bias) + f_peep) + .sigmoid(); + } else { + f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + + f.constant(forget_bias)) + .sigmoid(); + } + + // cs = ci .* i + f .* cs_prev + cs.device(d) = i * ci + f * cs_prev; + + if (cell_clip > 0.0f) { + cs.device(d) = + cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); + } + + // co = tanh(cs) + co.device(d) = cs.tanh(); + + // Output gate. + if (use_peephole) { + auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape); + o.device(d) = + (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep) + .sigmoid(); + } else { + o.device(d) = + icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid(); + } + + // h = o .* co + h.device(d) = o * co; +} + +template +void LSTMBlockCellBpropWithEigen( + const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, + bool use_peephole, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, + typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, + typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, + typename TTypes::Vec wco_grad) { + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; + + Eigen::array p_shape({1, cell.cell_size()}); + Eigen::array p_broadcast_shape({cell.batch_size(), 1}); + if (use_peephole) { + dcs.device(d) = + dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; + + dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di; + dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci; + dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df; + dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_; + + cs_prev_grad.device(d) = dcs * f; + if (use_peephole) { + cs_prev_grad.device(d) = + cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); + wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); + wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); + } +} + #define DEFINE_CPU_SPECS(T) \ template <> \ void LSTMBlockCellFprop::operator()( \ @@ -55,7 +191,7 @@ namespace functor { typename TTypes::Matrix f, typename TTypes::Matrix o, \ typename TTypes::Matrix ci, typename TTypes::Matrix co, \ typename TTypes::Matrix icfo, typename TTypes::Matrix h) { \ - LSTMBlockCellFpropWithEigen( \ + LSTMBlockCellFpropWithEigen( \ *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \ h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h); \ } \ diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index 30a4b447068bf983714aea0f8da6dfecb96a5c7b..1906581b16b2e76243320bc67c8ac831323fb8e7 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -169,88 +169,6 @@ struct LSTMBlockCellFprop : public LSTMBlockCell { typename TTypes::Matrix h); }; -// TODO(b/63339763): Once GPUDevice implementation no longer relies on Eigen, -// move into lstm_ops.cc. -template -void LSTMBlockCellFpropWithEigen( - const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, - const T forget_bias, const T cell_clip, bool use_peephole, - typename TTypes::ConstMatrix x, typename TTypes::ConstMatrix cs_prev, - typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, - typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, - typename TTypes::ConstVec wco, typename TTypes::ConstVec b, - typename TTypes::Matrix xh, typename TTypes::Matrix i, - typename TTypes::Matrix cs, typename TTypes::Matrix f, - typename TTypes::Matrix o, typename TTypes::Matrix ci, - typename TTypes::Matrix co, typename TTypes::Matrix icfo, - typename TTypes::Matrix h) { - // Concat xh = [x, h]. - xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x; - xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev; - - // states1 = xh * w + b - typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); - TensorBlasGemm::compute(ctx, d, false, false, T(1), - const_xh, w, T(0), icfo); - Eigen::array b_shape({1, b.dimensions()[0]}); - Eigen::array broadcast_shape({cell.batch_size(), 1}); - icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); - - Eigen::array p_shape({1, cell.cell_size()}); - Eigen::array p_broadcast_shape({cell.batch_size(), 1}); - - // Input gate. - if (use_peephole) { - auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape); - i.device(d) = - (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep) - .sigmoid(); - } else { - i.device(d) = - icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid(); - } - - // Cell input. - ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh(); - - // Forget gate (w/ bias). - if (use_peephole) { - auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias) + f_peep) - .sigmoid(); - } else { - f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias)) - .sigmoid(); - } - - // cs = ci .* i + f .* cs_prev - cs.device(d) = i * ci + f * cs_prev; - - if (cell_clip > 0.0f) { - cs.device(d) = - cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); - } - - // co = tanh(cs) - co.device(d) = cs.tanh(); - - // Output gate. - if (use_peephole) { - auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape); - o.device(d) = - (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep) - .sigmoid(); - } else { - o.device(d) = - icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid(); - } - - // h = o .* co - h.device(d) = o * co; -} - // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for // GPUDevice implementation. template @@ -278,64 +196,6 @@ struct LSTMBlockCellBprop : public LSTMBlockCell { typename TTypes::Vec wco_grad); }; -// TODO(b/63339763): Once GPUDevice implementation no longer relies on Eigen, -// move into lstm_ops.cc. -template -void LSTMBlockCellBpropWithEigen( - const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, - bool use_peephole, typename TTypes::ConstMatrix x, - typename TTypes::ConstMatrix cs_prev, - typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, - typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, - typename TTypes::ConstVec wco, typename TTypes::ConstVec b, - typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, - typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, - typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, - typename TTypes::ConstMatrix cs_grad, - typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, - typename TTypes::Matrix dcs, typename TTypes::Matrix dci, - typename TTypes::Matrix df, typename TTypes::Matrix di, - typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, - typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, - typename TTypes::Vec wco_grad) { - // do[t] = sigm'(o[t]) .* dh[t] .* co[t] - do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; - - // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] - dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; - - Eigen::array p_shape({1, cell.cell_size()}); - Eigen::array p_broadcast_shape({cell.batch_size(), 1}); - if (use_peephole) { - dcs.device(d) = - dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); - } - - // dci[t] = tanh'(ci[t]) dcs[t] i[t] - dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; - - // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] - df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; - - // di[t] = sigm'(i[t]) dcs[t] ci[t] - di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; - - dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di; - dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci; - dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df; - dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_; - - cs_prev_grad.device(d) = dcs * f; - if (use_peephole) { - cs_prev_grad.device(d) = - cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + - df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); - wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); - wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); - } -} - template struct BlockLSTMBprop : public LSTMBlockCell { BlockLSTMBprop(const int batch_size, const int input_size, diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index e18f8079a384b97bde6223c06944b6c5226bee21..6d3758fef15e7130b740a377d8bcd41d31203299 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -20,15 +20,334 @@ limitations under the License. #include "tensorflow/contrib/rnn/kernels/lstm_ops.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; -// TODO(b/63339763): Provide an alternative implementation for -// LSTMBlockCell{F,B}prop that doesn't rely on Eigen. +namespace { + +// Adds bias, applies non-linearities and gates. +// +// Launch with a 2D setup such that there is one thread per (example, +// activation) with 'x' governing example index and 'y' governing activation. +// +// Launch with blocks of (batch x 32) +// +// TODO(b/67600500): Try making 'use_peephole' a template parameter. +template +__global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, + const T* wci, const T* wcf, const T* wco, T* o, T* h, + T* ci, T* cs, T* co, T* i, T* f, const T forget_bias, + const T cell_clip, const int batch_size, + const int cell_size) { + const int batch_id = blockIdx.x * blockDim.x + threadIdx.x; + const int act_id = blockIdx.y * blockDim.y + threadIdx.y; + + if (batch_id >= batch_size || act_id >= cell_size) return; + + // The following code assumes the input arrays are of the following + // shapes and interpretations. + // + // 1) 'icfo' is a matrix such that, + // + // cell_size cell_size cell_size cell_size + // +----------+----------+----------+----------+ + // | | | | | + // | i | c | f | o | batch_size + // | | | | | + // +----------+----------+----------+----------+ + // + // 'gid' is the index assigned to this thread for 'icfo' in the 'i' submatrix. + // + // 2) 'b' is a vector such that, + // + // cell_size cell_size cell_size cell_size + // +----------+----------+----------+----------+ + // | i | c | f | o | 1 + // +----------+----------+----------+----------+ + // + // 'act_id' is the index assigned to this thread for 'b' in the 'i' subvector. + // + // 3) 'wc{i,f,o}' are vectors such that, + // + // cell_size + // +----------+ + // | i | 1 + // +----------+ + // + // 'act_id' is the index to this thread. + // + // 4) All other matrices have the form, + // + // cell_size + // +----------+ + // | | + // | i | batch_size + // | | + // +----------+ + // + // 'cid' is the index assigned to this thread. + // + const int gid = batch_id * cell_size * 4 + act_id; + const int cid = batch_id * cell_size + act_id; + Eigen::internal::scalar_sigmoid_op sigmoid_op; + Eigen::internal::scalar_tanh_op tanh_op; + Eigen::scalar_clip_op clip_op; + + T i_local; + if (use_peephole) { + i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id] + + cs_prev[cid] * wci[act_id]); + } else { + i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id]); + } + i[cid] = i_local; + + const T ci_local = + tanh_op(icfo[1 * cell_size + gid] + b[1 * cell_size + act_id]); + ci[cid] = ci_local; + + T f_local; + if (use_peephole) { + f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] + + forget_bias + cs_prev[cid] * wcf[act_id]); + } else { + f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] + + forget_bias); + } + f[cid] = f_local; + + T cs_local = i_local * ci_local + f_local * cs_prev[cid]; + if (cell_clip > 0.0) { + cs_local = clip_op(cs_local, cell_clip); + } + cs[cid] = cs_local; + + const T co_local = tanh_op(cs_local); + co[cid] = co_local; + + T o_local; + if (use_peephole) { + o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id] + + cs_local * wco[act_id]); + } else { + o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id]); + } + o[cid] = o_local; + + h[cid] = o_local * co_local; +} + +// Concatenate 'x' and 'h' and copy their contents into 'xh'. +template +__global__ void concat_xh(T* xh, const T* x, const T* h_prev, + const int batch_size, const int cell_size, + const int input_size) { + // Assumes 'x', 'h', and 'xh' are of the following shape, + // + // input_size cell_size + // +----------+----------+ + // | | | + // | x | h | batch_size + // | | | + // +----------+----------+ + // + const int gid = blockDim.x * blockIdx.x + threadIdx.x; + const int width = input_size + cell_size; + + if (gid >= width * batch_size) return; + + const int output_row = gid / width; + const int output_col = gid % width; + + if (output_col < input_size) { // x + xh[gid] = x[output_row * input_size + output_col]; + } else { // h + xh[gid] = h_prev[output_row * cell_size + output_col - input_size]; + } +} + +template +void LSTMBlockCellFpropWithCUDA( + OpKernelContext* ctx, const GPUDevice& d, const T forget_bias, + const T cell_clip, bool use_peephole, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::Matrix xh, typename TTypes::Matrix i, + typename TTypes::Matrix cs, typename TTypes::Matrix f, + typename TTypes::Matrix o, typename TTypes::Matrix ci, + typename TTypes::Matrix co, typename TTypes::Matrix icfo, + typename TTypes::Matrix h, int batch_size, int cell_size, + int input_size) { + const cudaStream_t& cu_stream = GetCudaStream(ctx); + + // Concatenate xh = [x, h]. + // + // Each block is assigned 128 threads. Good values are in [128, 1024] and are + // divisible by 32 (the size of a warp). The number of blocks is such that + // there are enough to process all the data. + const int block_dim = 128; + const int grid_dim = + Eigen::divup(batch_size * (cell_size + input_size), block_dim); + concat_xh<<>>( + xh.data(), x.data(), h_prev.data(), batch_size, cell_size, input_size); + + // states1 = xh * w + typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + + // Add bias, apply non-linearities and gating. + // + // Use 2D blocks. The number of threads per block is equal to x * y, where x = + // min(batch_size, 8) and y = 32. See above for guidance on number of + // threads. + dim3 block_dim_2d(std::min(batch_size, 8), 32); + dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast(block_dim_2d.x)), + Eigen::divup(cell_size, static_cast(block_dim_2d.y))); + + if (use_peephole) { + lstm_gates<<>>( + icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), + wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + } else { + lstm_gates<<>>( + icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), + wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + } +} + +template +__global__ void lstm_gates_bprop( + const T* cs_prev, // [batch_size, cell_size] + const T* h_prev, // [batch_size, cell_size] + const T* w, // [input_size + cell_size, 4 * cell_size] + const T* wci, // [cell_size] + const T* wcf, // [cell_size] + const T* wco, // [cell_size] + const T* b, // [4 * cell_size] + const T* i, // [batch_size, cell_size] + const T* cs, // [batch_size, cell_size] + const T* f, // [batch_size, cell_size] + const T* o, // [batch_size, cell_size] + const T* ci, // [batch_size, cell_size] + const T* co, // [batch_size, cell_size] + const T* cs_grad, // [batch_size, cell_size] + const T* h_grad, // [batch_size, cell_size] + T* do_, // [batch_size, cell_size] + T* dcs, // [batch_size, cell_size] + T* dci, // [batch_size, cell_size] + T* df, // [batch_size, cell_size] + T* di, // [batch_size, cell_size] + T* dicfo, // [input_size + cell_size, 4 * cell_size] + T* cs_prev_grad, // [batch_size, cell_size] + const int batch_size, const int cell_size, const bool use_peephole) { + const int batch_id = blockIdx.x * blockDim.x + threadIdx.x; + const int act_id = blockIdx.y * blockDim.y + threadIdx.y; + + if (batch_id >= batch_size || act_id >= cell_size) return; + + const int gid = batch_id * cell_size * 4 + act_id; + const int cid = batch_id * cell_size + act_id; + + const T one = static_cast(1.0f); + + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + const T o_local = o[cid]; + const T h_grad_local = h_grad[cid]; + const T co_local = co[cid]; + const T ci_local = ci[cid]; + const T do_local = o_local * (one - o_local) * h_grad_local * co_local; + const T i_local = i[cid]; + const T f_local = f[cid]; + + do_[cid] = do_local; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + T dcs_local = + (one - co_local * co_local) * h_grad_local * o_local + cs_grad[cid]; + if (use_peephole) { + dcs_local += do_local * wco[act_id]; + } + dcs[cid] = dcs_local; + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + const T dci_local = (one - ci_local * ci_local) * dcs_local * i_local; + dci[cid] = dci_local; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + const T df_local = f_local * (one - f_local) * dcs_local * cs_prev[cid]; + df[cid] = df_local; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + const T di_local = i_local * (one - i_local) * dcs_local * ci_local; + di[cid] = di_local; + + dicfo[gid + 0 * cell_size] = di_local; + dicfo[gid + 1 * cell_size] = dci_local; + dicfo[gid + 2 * cell_size] = df_local; + dicfo[gid + 3 * cell_size] = do_local; + + cs_prev_grad[cid] = dcs_local * f_local; + if (use_peephole) { + cs_prev_grad[cid] += di_local * wci[act_id] + df_local * wcf[act_id]; + } +} + +template +void LSTMBlockCellBpropWithCUDA( + OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, + typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, + typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, + typename TTypes::Vec wco_grad, const int batch_size, const int cell_size, + const bool use_peephole) { + const cudaStream_t& cu_stream = GetCudaStream(ctx); + + dim3 block_dim_2d(std::min(batch_size, 8), 32); + dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast(block_dim_2d.x)), + Eigen::divup(cell_size, static_cast(block_dim_2d.y))); + + lstm_gates_bprop<<>>( + cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(), + wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(), + co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(), + dci.data(), df.data(), di.data(), dicfo.data(), cs_prev_grad.data(), + batch_size, cell_size, use_peephole); + + if (use_peephole) { + Eigen::array p_shape({1, cell_size}); + Eigen::array p_broadcast_shape({batch_size, 1}); + cs_prev_grad.device(d) = + cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); + wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); + wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); + } +} + +} // namespace + #define DEFINE_GPU_SPECS(T) \ template struct TensorZero; \ template struct TensorUnalignedZero; \ @@ -49,9 +368,10 @@ typedef Eigen::GpuDevice GPUDevice; typename TTypes::Matrix f, typename TTypes::Matrix o, \ typename TTypes::Matrix ci, typename TTypes::Matrix co, \ typename TTypes::Matrix icfo, typename TTypes::Matrix h) { \ - LSTMBlockCellFpropWithEigen( \ - *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \ - h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h); \ + LSTMBlockCellFpropWithCUDA(ctx, d, forget_bias, cell_clip, use_peephole, \ + x, cs_prev, h_prev, w, wci, wcf, wco, b, xh, i, \ + cs, f, o, ci, co, icfo, h, batch_size_, \ + cell_size_, input_size_); \ } \ template <> \ void LSTMBlockCellBprop::operator()( \ @@ -73,10 +393,10 @@ typedef Eigen::GpuDevice GPUDevice; typename TTypes::Matrix cs_prev_grad, \ typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ typename TTypes::Vec wco_grad) { \ - LSTMBlockCellBpropWithEigen( \ - *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \ - i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, \ - cs_prev_grad, wci_grad, wcf_grad, wco_grad); \ + LSTMBlockCellBpropWithCUDA( \ + ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \ + cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, cs_prev_grad, wci_grad, \ + wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \ } \ template struct LSTMBlockCellFprop; \ template struct LSTMBlockCellBprop; \ diff --git a/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py b/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py new file mode 100644 index 0000000000000000000000000000000000000000..a48cd58706e72516f18098e643c0fa867d33beb2 --- /dev/null +++ b/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for benchmarking OpKernels.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import time + +from tensorflow.python.framework import ops + + +def device(use_gpu=False): + """TensorFlow device to assign ops to.""" + if use_gpu: + return ops.device("/gpu:0") + return ops.device("/cpu:0") + + +def seconds_per_run(op, sess, num_runs=50): + """Number of seconds taken to execute 'op' once on average.""" + for _ in range(2): + sess.run(op) + + start_time = time.time() + for _ in range(num_runs): + sess.run(op) + + end_time = time.time() + time_taken = (end_time - start_time) / num_runs + return time_taken + + +def dict_product(dicts): + """Constructs iterator over outer product of entries in a dict-of-lists. + + Example: + >>> dict_products({"a": [1,2], "b": [3, 4]}) + >>> [{"a": 1, "b": 3}, + {"a": 1, "b": 4}, + {"a": 2, "b": 3}, + {"a": 2, "b": 4}] + + Args: + dicts: dictionary with string keys and list values. + + Yields: + Individual dicts from outer product. + """ + keys, values = zip(*dicts.items()) + for config_values in itertools.product(*values): + yield dict(zip(keys, config_values)) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index deebadc142a3db25b6cf61f7fd5b3015d76e9181..6b6cdfa242aa5c85049f34c21fecc89e64c44ac4 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -40,7 +40,6 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test -from tensorflow.python.framework import test_util # pylint: enable=protected-access @@ -450,6 +449,17 @@ class RNNCellTest(test.TestCase): outputs, _ = cell(x, m) self.assertTrue("cpu:14159" in outputs.device.lower()) + def _retrieve_cpu_gpu_stats(self, run_metadata): + cpu_stats = None + gpu_stats = None + step_stats = run_metadata.step_stats + for ds in step_stats.dev_stats: + if "cpu:0" in ds.device[-5:].lower(): + cpu_stats = ds.node_stats + if "gpu:0" == ds.device[-5:].lower(): + gpu_stats = ds.node_stats + return cpu_stats, gpu_stats + def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): if not test.is_gpu_available(): # Can't perform this test w/o a GPU @@ -471,10 +481,7 @@ class RNNCellTest(test.TestCase): sess.run([variables_lib.global_variables_initializer()]) _ = sess.run(outputs, options=opts, run_metadata=run_metadata) - step_stats = run_metadata.step_stats - ix = 0 if gpu_dev in step_stats.dev_stats[0].device else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 40a3fb2fb0b174681252265d593de2935ee2efa2..9cea2ec79a982e4fb362ec564eb72b3894917842 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -25,10 +25,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as rnn_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -167,7 +169,7 @@ class RNNTest(test.TestCase): self.assertEqual(out.get_shape(), inp.get_shape()) self.assertEqual(out.dtype, inp.dtype) - with self.test_session(use_gpu=False) as sess: + with self.test_session(use_gpu=True) as sess: input_value = np.random.randn(batch_size, input_size) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) @@ -202,7 +204,7 @@ class RNNTest(test.TestCase): self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list()) self.assertEqual(out.dtype, inp.dtype) - with self.test_session(use_gpu=False) as sess: + with self.test_session(use_gpu=True) as sess: input_value = np.random.randn(batch_size, input_size) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) full_dropout_values = sess.run(dropped_outputs, @@ -213,7 +215,7 @@ class RNNTest(test.TestCase): for d_v in full_dropout_values[:-1]: # Add 1.0 to dropped_out (all zeros) self.assertAllClose(d_v, np.ones_like(input_value)) - def _testDynamicCalculation(self, use_gpu): + def testDynamicCalculation(self): cell = Plus1RNNCell() sequence_length = array_ops.placeholder(dtypes.int64) batch_size = 2 @@ -228,7 +230,7 @@ class RNNTest(test.TestCase): cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) self.assertEqual(len(dynamic_outputs), len(inputs)) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session(use_gpu=True) as sess: input_value = np.random.randn(batch_size, input_size) dynamic_values = sess.run( dynamic_outputs, @@ -259,10 +261,6 @@ class RNNTest(test.TestCase): np.vstack((1.0 * (1 + 1) * np.ones((input_size)), 1.0 * (2 + 1) * np.ones((input_size))))) - def testDynamicCalculation(self): - self._testDynamicCalculation(True) - self._testDynamicCalculation(False) - def _testScope(self, factory, prefix="prefix", use_outer_scope=True): with self.test_session(use_gpu=True, graph=ops_lib.Graph()): if use_outer_scope: @@ -307,12 +305,12 @@ class LSTMTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _testNoProjNoSharding(self, use_gpu): + def testNoProjNoSharding(self): num_units = 3 input_size = 5 batch_size = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) cell = rnn_cell.LSTMCell( @@ -330,12 +328,12 @@ class LSTMTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) sess.run(outputs, feed_dict={inputs[0]: input_value}) - def _testCellClipping(self, use_gpu): + def testCellClipping(self): num_units = 3 input_size = 5 batch_size = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) cell = rnn_cell.LSTMCell( @@ -361,12 +359,12 @@ class LSTMTest(test.TestCase): # if cell c is clipped to 0, tanh(c) = 0 => m==0 self.assertAllEqual(value, np.zeros((batch_size, num_units))) - def _testNoProjNoShardingSimpleStateSaver(self, use_gpu): + def testNoProjNoShardingSimpleStateSaver(self): num_units = 3 input_size = 5 batch_size = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) state_saver = TestStateSaver(batch_size, 2 * num_units) @@ -491,13 +489,13 @@ class LSTMTest(test.TestCase): self.assertAllEqual(last_states[i], named_saved_states[flat_state_names[i]]) - def _testProjNoSharding(self, use_gpu): + def testProjNoSharding(self): num_units = 3 input_size = 5 batch_size = 2 num_proj = 4 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) inputs = max_length * [ @@ -582,7 +580,7 @@ class LSTMTest(test.TestCase): state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value}) self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v)) - def _testProjSharding(self, use_gpu): + def testProjSharding(self): num_units = 3 input_size = 5 batch_size = 2 @@ -590,7 +588,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) @@ -616,7 +614,7 @@ class LSTMTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) sess.run(outputs, feed_dict={inputs[0]: input_value}) - def _testDoubleInput(self, use_gpu): + def testDoubleInput(self): num_units = 3 input_size = 5 batch_size = 2 @@ -624,7 +622,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed) inputs = max_length * [ array_ops.placeholder( @@ -653,7 +651,7 @@ class LSTMTest(test.TestCase): values = sess.run(outputs, feed_dict={inputs[0]: input_value}) self.assertEqual(values[0].dtype, input_value.dtype) - def _testShardNoShardEquivalentOutput(self, use_gpu): + def testShardNoShardEquivalentOutput(self): num_units = 3 input_size = 5 batch_size = 2 @@ -661,7 +659,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: inputs = max_length * [ array_ops.placeholder( dtypes.float32, shape=(None, input_size)) @@ -708,7 +706,7 @@ class LSTMTest(test.TestCase): for (s_noshard, s_shard) in zip(state_values_noshard, state_values_shard): self.assertAllClose(s_noshard, s_shard, atol=1e-3) - def _testDoubleInputWithDropoutAndDynamicCalculation(self, use_gpu): + def testDoubleInputWithDropoutAndDynamicCalculation(self): """Smoke test for using LSTM with doubles, dropout, dynamic calculation.""" num_units = 3 @@ -718,7 +716,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: sequence_length = array_ops.placeholder(dtypes.int64) initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) @@ -843,44 +841,13 @@ class LSTMTest(test.TestCase): for out0, out1 in zip(outputs0_values, outputs1_values): self.assertAllEqual(out0, out1) - def testNoProjNoShardingSimpleStateSaver(self): - self._testNoProjNoShardingSimpleStateSaver(use_gpu=False) - self._testNoProjNoShardingSimpleStateSaver(use_gpu=True) - - def testNoProjNoSharding(self): - self._testNoProjNoSharding(use_gpu=False) - self._testNoProjNoSharding(use_gpu=True) - - def testCellClipping(self): - self._testCellClipping(use_gpu=False) - self._testCellClipping(use_gpu=True) - - def testProjNoSharding(self): - self._testProjNoSharding(use_gpu=False) - self._testProjNoSharding(use_gpu=True) - - def testProjSharding(self): - self._testProjSharding(use_gpu=False) - self._testProjSharding(use_gpu=True) - - def testShardNoShardEquivalentOutput(self): - self._testShardNoShardEquivalentOutput(use_gpu=False) - self._testShardNoShardEquivalentOutput(use_gpu=True) - - def testDoubleInput(self): - self._testDoubleInput(use_gpu=False) - self._testDoubleInput(use_gpu=True) - - def testDoubleInputWithDropoutAndDynamicCalculation(self): - self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=False) - self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=True) - def testDynamicRNNAllowsUnknownTimeDimension(self): inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20]) cell = rnn_cell.GRUCell(30) # Smoke test, this should not raise an error rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithTupleStates(self): num_units = 3 input_size = 5 @@ -888,13 +855,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) cell = rnn_cell.LSTMCell( num_units, @@ -924,21 +898,34 @@ class LSTMTest(test.TestCase): self.assertEqual(state_dynamic[0], state_dynamic.c) self.assertEqual(state_dynamic[1], state_dynamic.h) - variables_lib.global_variables_initializer().run() - - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(state_static, - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(state_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + variables_lib.global_variables_initializer().run() + input_value = np.random.randn(batch_size, input_size) + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + state_static, feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + state_dynamic, feed_dict={ + inputs[0]: input_value + }) + + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithNestedTupleStates(self): num_units = 3 input_size = 5 @@ -946,13 +933,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) def _cell(i): @@ -993,43 +987,58 @@ class LSTMTest(test.TestCase): sequence_length=sequence_length, scope=scope) - variables_lib.global_variables_initializer().run() - - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(nest.flatten(state_static), - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(nest.flatten(state_dynamic), - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + input_value = np.random.randn(batch_size, input_size) + variables_lib.global_variables_initializer().run() + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + nest.flatten(state_static), feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + nest.flatten(state_dynamic), feed_dict={ + inputs[0]: input_value + }) + + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + state_static = [s.numpy() for s in nest.flatten(state_static)] + state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)] + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) - def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length): + def _testDynamicEquivalentToStaticRNN(self, use_sequence_length): time_steps = 8 num_units = 3 num_proj = 4 input_size = 5 batch_size = 2 - input_values = np.random.randn(time_steps, batch_size, input_size) + input_values = np.random.randn(time_steps, batch_size, input_size).astype( + np.float32) if use_sequence_length: sequence_length = np.random.randint(0, time_steps, size=batch_size) else: sequence_length = None - ########### Step 1: Run static graph and generate readouts - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + in_graph_mode = context.in_graph_mode() + + # TODO(b/68017812): Eager ignores operation seeds, so we need to create a + # single cell and reuse it across the static and dynamic RNNs. Remove this + # special case once is fixed. + if not in_graph_mode: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( num_units, use_peepholes=True, @@ -1037,63 +1046,85 @@ class LSTMTest(test.TestCase): num_proj=num_proj, state_is_tuple=False) + ########### Step 1: Run static graph and generate readouts + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) + inputs = array_ops.unstack(concat_inputs) + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, seed=self._seed) + + # TODO(akshayka): Remove special case once b/68017812 is fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) + with variable_scope.variable_scope("dynamic_scope"): outputs_static, state_static = rnn.static_rnn( cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) - feeds = {concat_inputs: input_values} - - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) - - # Generate gradients of sum of outputs w.r.t. inputs - static_gradients = gradients_impl.gradients( - outputs_static + [state_static], [concat_inputs]) - - # Generate gradients of individual outputs w.r.t. inputs - static_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - # pylint: disable=bad-builtin - static_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Test forward pass - values_static = sess.run(outputs_static, feed_dict=feeds) - (state_value_static,) = sess.run((state_static,), feed_dict=feeds) - - # Test gradients to inputs and variables w.r.t. outputs & final state - static_grad_values = sess.run(static_gradients, feed_dict=feeds) - - static_individual_grad_values = sess.run(static_individual_gradients, - feed_dict=feeds) - - static_individual_var_grad_values = sess.run( - static_individual_variable_gradients, feed_dict=feeds) + if in_graph_mode: + # Generate gradients and run sessions to obtain outputs + feeds = {concat_inputs: input_values} + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Generate gradients of sum of outputs w.r.t. inputs + static_gradients = gradients_impl.gradients( + outputs_static + [state_static], [concat_inputs]) + # Generate gradients of individual outputs w.r.t. inputs + static_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + # pylint: disable=bad-builtin + static_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Test forward pass + values_static = sess.run(outputs_static, feed_dict=feeds) + (state_value_static,) = sess.run((state_static,), feed_dict=feeds) + + # Test gradients to inputs and variables w.r.t. outputs & final state + static_grad_values = sess.run(static_gradients, feed_dict=feeds) + + static_individual_grad_values = sess.run(static_individual_gradients, + feed_dict=feeds) + + static_individual_var_grad_values = sess.run( + static_individual_variable_gradients, feed_dict=feeds) ########## Step 2: Run dynamic graph and generate readouts - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( - num_units, - use_peepholes=True, - initializer=initializer, - num_proj=num_proj, - state_is_tuple=False) + # TODO(akshayka): Remove this special case once b/68017812 is + # fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) with variable_scope.variable_scope("dynamic_scope"): outputs_dynamic, state_dynamic = rnn.dynamic_rnn( @@ -1104,81 +1135,86 @@ class LSTMTest(test.TestCase): dtype=dtypes.float32) split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps) - feeds = {concat_inputs: input_values} + if in_graph_mode: + feeds = {concat_inputs: input_values} - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) + + # Generate gradients of sum of outputs w.r.t. inputs + dynamic_gradients = gradients_impl.gradients( + split_outputs_dynamic + [state_dynamic], [concat_inputs]) - # Generate gradients of sum of outputs w.r.t. inputs - dynamic_gradients = gradients_impl.gradients( - split_outputs_dynamic + [state_dynamic], [concat_inputs]) - - # Generate gradients of several individual outputs w.r.t. inputs - dynamic_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) - - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - dynamic_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) - - # Test forward pass - values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) - (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) - - # Test gradients to inputs and variables w.r.t. outputs & final state - dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) - - dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, - feed_dict=feeds) - - dynamic_individual_var_grad_values = sess.run( - dynamic_individual_variable_gradients, feed_dict=feeds) + # Generate gradients of several individual outputs w.r.t. inputs + dynamic_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) + + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + dynamic_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) + + # Test forward pass + values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) + (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) + + # Test gradients to inputs and variables w.r.t. outputs & final state + dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) + + dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, + feed_dict=feeds) + + dynamic_individual_var_grad_values = sess.run( + dynamic_individual_variable_gradients, feed_dict=feeds) ######### Step 3: Comparisons + if not in_graph_mode: + values_static = outputs_static + values_dynamic = split_outputs_dynamic + state_value_static = state_static + state_value_dynamic = state_dynamic + self.assertEqual(len(values_static), len(values_dynamic)) for (value_static, value_dynamic) in zip(values_static, values_dynamic): self.assertAllEqual(value_static, value_dynamic) self.assertAllEqual(state_value_static, state_value_dynamic) - self.assertAllEqual(static_grad_values, dynamic_grad_values) + if in_graph_mode: - self.assertEqual( - len(static_individual_grad_values), len(dynamic_individual_grad_values)) - self.assertEqual( - len(static_individual_var_grad_values), - len(dynamic_individual_var_grad_values)) + self.assertAllEqual(static_grad_values, dynamic_grad_values) - for i, (a, b) in enumerate( - zip(static_individual_grad_values, dynamic_individual_grad_values)): - tf_logging.info("Comparing individual gradients iteration %d" % i) - self.assertAllEqual(a, b) + self.assertEqual( + len(static_individual_grad_values), + len(dynamic_individual_grad_values)) + self.assertEqual( + len(static_individual_var_grad_values), + len(dynamic_individual_var_grad_values)) - for i, (a, b) in enumerate( - zip(static_individual_var_grad_values, - dynamic_individual_var_grad_values)): - tf_logging.info("Comparing individual variable gradients iteration %d" % - i) - self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_grad_values, dynamic_individual_grad_values)): + tf_logging.info("Comparing individual gradients iteration %d" % i) + self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_var_grad_values, + dynamic_individual_var_grad_values)): + tf_logging.info("Comparing individual variable gradients iteration %d" % + i) + self.assertAllEqual(a, b) + + @test_util.run_in_graph_and_eager_modes() def testDynamicEquivalentToStaticRNN(self): - self._testDynamicEquivalentToStaticRNN( - use_gpu=False, use_sequence_length=False) - self._testDynamicEquivalentToStaticRNN( - use_gpu=True, use_sequence_length=False) - self._testDynamicEquivalentToStaticRNN( - use_gpu=False, use_sequence_length=True) - self._testDynamicEquivalentToStaticRNN( - use_gpu=True, use_sequence_length=True) + self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) + self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) class BidirectionalRNNTest(test.TestCase): @@ -1188,7 +1224,6 @@ class BidirectionalRNNTest(test.TestCase): np.random.seed(self._seed) def _createBidirectionalRNN(self, - use_gpu, use_shape, use_sequence_length, scope=None): @@ -1227,10 +1262,10 @@ class BidirectionalRNNTest(test.TestCase): return input_value, inputs, outputs, state_fw, state_bw, sequence_length - def _testBidirectionalRNN(self, use_gpu, use_shape): - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + def _testBidirectionalRNN(self, use_shape): + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( - self._createBidirectionalRNN(use_gpu, use_shape, True)) + self._createBidirectionalRNN(use_shape, True)) variables_lib.global_variables_initializer().run() # Run with pre-specified sequence length of 2, 3 out, s_fw, s_bw = sess.run( @@ -1272,10 +1307,10 @@ class BidirectionalRNNTest(test.TestCase): # exactly the same self.assertAllClose(s_fw, s_bw) - def _testBidirectionalRNNWithoutSequenceLength(self, use_gpu, use_shape): - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + def _testBidirectionalRNNWithoutSequenceLength(self, use_shape): + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: input_value, inputs, outputs, state_fw, state_bw, _ = ( - self._createBidirectionalRNN(use_gpu, use_shape, False)) + self._createBidirectionalRNN(use_shape, False)) variables_lib.global_variables_initializer().run() out, s_fw, s_bw = sess.run([outputs, state_fw, state_bw], feed_dict={inputs[0]: input_value}) @@ -1302,23 +1337,14 @@ class BidirectionalRNNTest(test.TestCase): self.assertAllClose(s_fw, s_bw) def testBidirectionalRNN(self): - self._testBidirectionalRNN(use_gpu=False, use_shape=False) - self._testBidirectionalRNN(use_gpu=True, use_shape=False) - self._testBidirectionalRNN(use_gpu=False, use_shape=True) - self._testBidirectionalRNN(use_gpu=True, use_shape=True) + self._testBidirectionalRNN(use_shape=False) + self._testBidirectionalRNN(use_shape=True) def testBidirectionalRNNWithoutSequenceLength(self): - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=False, use_shape=False) - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=True, use_shape=False) - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=False, use_shape=True) - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=True, use_shape=True) + self._testBidirectionalRNNWithoutSequenceLength(use_shape=False) + self._testBidirectionalRNNWithoutSequenceLength(use_shape=True) def _createBidirectionalDynamicRNN(self, - use_gpu, use_shape, use_state_tuple, use_time_major, @@ -1366,11 +1392,11 @@ class BidirectionalRNNTest(test.TestCase): return input_value, inputs, outputs, state_fw, state_bw, sequence_length - def _testBidirectionalDynamicRNN(self, use_gpu, use_shape, use_state_tuple, + def _testBidirectionalDynamicRNN(self, use_shape, use_state_tuple, use_time_major, use_sequence_length): - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( - self._createBidirectionalDynamicRNN(use_gpu, use_shape, + self._createBidirectionalDynamicRNN(use_shape, use_state_tuple, use_time_major, use_sequence_length)) variables_lib.global_variables_initializer().run() @@ -1435,14 +1461,13 @@ class BidirectionalRNNTest(test.TestCase): def testBidirectionalDynamicRNN(self): # Generate 2^5 option values # from [True, True, True, True, True] to [False, False, False, False, False] - options = itertools.product([True, False], repeat=5) + options = itertools.product([True, False], repeat=4) for option in options: self._testBidirectionalDynamicRNN( - use_gpu=option[0], - use_shape=option[1], - use_state_tuple=option[2], - use_time_major=option[3], - use_sequence_length=option[4]) + use_shape=option[0], + use_state_tuple=option[1], + use_time_major=option[2], + use_sequence_length=option[3]) def _testScope(self, factory, prefix="prefix", use_outer_scope=True): # REMARKS: factory(scope) is a function accepting a scope @@ -1471,7 +1496,7 @@ class BidirectionalRNNTest(test.TestCase): def factory(scope): return self._createBidirectionalRNN( - use_gpu=True, use_shape=True, use_sequence_length=True, scope=scope) + use_shape=True, use_sequence_length=True, scope=scope) self._testScope(factory, use_outer_scope=True) self._testScope(factory, use_outer_scope=False) @@ -1483,7 +1508,6 @@ class BidirectionalRNNTest(test.TestCase): def factory(scope): return self._createBidirectionalDynamicRNN( - use_gpu=True, use_shape=True, use_state_tuple=True, use_sequence_length=True, @@ -1761,7 +1785,7 @@ class GRUTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _testDynamic(self, use_gpu): + def testDynamic(self): time_steps = 8 num_units = 3 input_size = 5 @@ -1771,7 +1795,7 @@ class GRUTest(test.TestCase): sequence_length = np.random.randint(0, time_steps, size=batch_size) - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: concat_inputs = array_ops.placeholder( dtypes.float32, shape=(time_steps, batch_size, input_size)) @@ -1792,10 +1816,6 @@ class GRUTest(test.TestCase): sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds) - def testDynamic(self): - self._testDynamic(use_gpu=False) - self._testDynamic(use_gpu=True) - def _testScope(self, factory, prefix="prefix", use_outer_scope=True): with self.test_session(use_gpu=True, graph=ops_lib.Graph()): if use_outer_scope: @@ -2203,6 +2223,17 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): return run_metadata + def _retrieve_cpu_gpu_stats(self, run_metadata): + cpu_stats = None + gpu_stats = None + step_stats = run_metadata.step_stats + for ds in step_stats.dev_stats: + if "cpu:0" in ds.device[-5:].lower(): + cpu_stats = ds.node_stats + if "gpu:0" == ds.device[-5:].lower(): + gpu_stats = ds.node_stats + return cpu_stats, gpu_stats + def testRNNOnCPUCellOnGPU(self): if not test.is_gpu_available(): return # Test requires access to a GPU @@ -2210,10 +2241,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): gpu_dev = test.gpu_device_name() run_metadata = self._execute_rnn_on( rnn_device="/cpu:0", cell_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) @@ -2236,10 +2264,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): run_metadata = self._execute_rnn_on( rnn_device="/cpu:0", cell_device="/cpu:0", input_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) @@ -2255,10 +2280,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): gpu_dev = test.gpu_device_name() run_metadata = self._execute_rnn_on( input_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py index 4239e32ab93043c5054e5382e67e79047b9644bb..b865466cc75aa67fcd192f7726f65141409b896a 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py @@ -18,10 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time - import numpy as np +from tensorflow.contrib.rnn.python.kernel_tests import benchmarking from tensorflow.contrib.rnn.python.ops import gru_ops from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -333,20 +332,6 @@ class GRUBlockCellTest(test.TestCase): #### Benchmarking GRUBlockCell vs GRUCell. -def time_taken_by_op(op, sess, num_runs=50): - """Time taken by the Op.""" - for _ in range(2): - sess.run([op]) - - start_time = time.time() - for _ in range(num_runs): - sess.run([op]) - - end_time = time.time() - time_taken = end_time - start_time - return time_taken - - def training_gru_block_vs_gru_cell(batch_size, cell_size, input_size, @@ -357,7 +342,7 @@ def training_gru_block_vs_gru_cell(batch_size, ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: # Specify the device which is been used. - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): # Random initializers. seed = 1994 @@ -387,7 +372,8 @@ def training_gru_block_vs_gru_cell(batch_size, learning_rate).minimize(cost) # time for a training step. - basic_time_training = time_taken_by_op(optimizer, sess, iters) + basic_time_training = benchmarking.seconds_per_run( + optimizer, sess, iters) # Output from the basic GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -406,7 +392,8 @@ def training_gru_block_vs_gru_cell(batch_size, learning_rate).minimize(cost) # time for a training step. - block_time_training = time_taken_by_op(optimizer, sess, iters) + block_time_training = benchmarking.seconds_per_run( + optimizer, sess, iters) performance_training = ( basic_time_training - block_time_training) * 100 / basic_time_training @@ -429,7 +416,7 @@ def inference_gru_block_vs_gru_cell(batch_size, """Benchmark inference speed between GRUBlockCell vs GRUCell.""" ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): # Random initializers. seed = 1994 @@ -451,7 +438,8 @@ def inference_gru_block_vs_gru_cell(batch_size, time_major=True, dtype=dtypes.float32) sess.run([variables.global_variables_initializer()]) - basic_time_inference = time_taken_by_op(outputs_dynamic, sess, iters) + basic_time_inference = benchmarking.seconds_per_run( + outputs_dynamic, sess, iters) # Output from the block GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -463,7 +451,8 @@ def inference_gru_block_vs_gru_cell(batch_size, time_major=True, dtype=dtypes.float32) sess.run([variables.global_variables_initializer()]) - block_time_inference = time_taken_by_op(outputs_dynamic, sess, iters) + block_time_inference = benchmarking.seconds_per_run( + outputs_dynamic, sess, iters) performance_inference = (basic_time_inference - block_time_inference ) * 100 / basic_time_inference @@ -484,7 +473,7 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, """Benchmark single bprop step speed between GRUBlockCell vs GRUCell.""" ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): initializer = init_ops.random_uniform_initializer(-1, 1, seed=1989) # Inputs x = vs.get_variable("x", [batch_size, input_size]) @@ -496,7 +485,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, array_ops.identity(h)) sess.run([variables.global_variables_initializer()]) grad_output_wrt_input = gradients_impl.gradients([output], h) - basic_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters) + basic_time_bprop = benchmarking.seconds_per_run(grad_output_wrt_input, + sess, iters) # Output from the block GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -504,7 +494,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, array_ops.identity(h)) sess.run([variables.global_variables_initializer()]) grad_output_wrt_input = gradients_impl.gradients([output], h) - block_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters) + block_time_bprop = benchmarking.seconds_per_run(grad_output_wrt_input, + sess, iters) performance_inference = ( basic_time_bprop - block_time_bprop) * 100 / basic_time_bprop @@ -526,23 +517,29 @@ class BenchmarkGRUBlock(test.Benchmark): print("batch_size, cell_size, input_size, time_steps, GPU, " "basic_time_training, block_time_training, performance_training[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - for time_steps in [50]: - basic_time, block_time = training_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, time_steps, use_gpu, iters) - self.report_benchmark( - name="GRUCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=block_time) + + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512], + "time_steps": [50] + }): + basic_time, block_time = training_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=block_time) def benchmarkInferenceBlockGRUVsGRUCell(self): print("--------------------------------------------------------------") @@ -551,23 +548,28 @@ class BenchmarkGRUBlock(test.Benchmark): "batch_size, cell_size, input_size, time_steps, GPU, " "basic_time_inference, block_time_inference, performance_inference[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - for time_steps in [50]: - basic_time, block_time = inference_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, time_steps, use_gpu, iters) - self.report_benchmark( - name="GRUCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" - % (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=block_time) + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512], + "time_steps": [50] + }): + basic_time, block_time = inference_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=block_time) def benchmarkSingleBpropStepBlockGRUVsGRUCell(self): print("--------------------------------------------------------------") @@ -575,22 +577,27 @@ class BenchmarkGRUBlock(test.Benchmark): print("batch_size, cell_size, input_size, GPU, basic_time, " "block_time, performance_inference[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - basic_time, block_time = single_bprop_step_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, use_gpu, iters) - self.report_benchmark( - name="GRUCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % - (batch_size, cell_size, input_size, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" - % (batch_size, cell_size, input_size, use_gpu), - iters=iters, - wall_time=block_time) + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512] + }): + basic_time, block_time = single_bprop_step_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"]), + iters=iters, + wall_time=block_time) print("--------------------------------------------------------------") diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index 0ec37411f5f3d9b6687c077bf967b046068644ab..a288072ae5da0751f1999128029f38bea933490e 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -20,7 +20,9 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.rnn.python.kernel_tests import benchmarking from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,6 +38,111 @@ from tensorflow.python.platform import test block_lstm = lstm_ops._block_lstm # pylint: disable=protected-access +def blocks_match(sess, use_peephole): + batch_size = 2 + input_size = 3 + cell_size = 4 + sequence_length = 4 + + inputs = [] + for _ in range(sequence_length): + inp = ops.convert_to_tensor( + np.random.randn(batch_size, input_size), dtype=dtypes.float32) + inputs.append(inp) + + initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212) + + with variable_scope.variable_scope("test", initializer=initializer): + # magic naming so that the cells pick up these variables and resuse them + if use_peephole: + wci = variable_scope.get_variable( + "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32) + wcf = variable_scope.get_variable( + "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtypes.float32) + wco = variable_scope.get_variable( + "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtypes.float32) + + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + + if use_peephole: + wci_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_i_diag", + initializer=wci.initialized_value()) + wcf_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_f_diag", + initializer=wcf.initialized_value()) + wco_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_o_diag", + initializer=wco.initialized_value()) + w_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/kernel", + initializer=w.initialized_value()) + b_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/bias", + initializer=b.initialized_value()) + + basic_cell = rnn_cell.LSTMCell( + cell_size, use_peepholes=use_peephole, state_is_tuple=True, reuse=True) + basic_outputs_op, basic_state_op = rnn.static_rnn( + basic_cell, inputs, dtype=dtypes.float32) + + if use_peephole: + _, _, _, _, _, _, block_outputs_op = block_lstm( + ops.convert_to_tensor(sequence_length, dtype=dtypes.int64), + inputs, + w, + b, + wci=wci, + wcf=wcf, + wco=wco, + cell_clip=0, + use_peephole=True) + else: + _, _, _, _, _, _, block_outputs_op = block_lstm( + ops.convert_to_tensor(sequence_length, dtype=dtypes.int64), + inputs, + w, + b, + cell_clip=0) + + with variable_scope.variable_scope("rnn/lstm_cell", reuse=True): + fused_cell = lstm_ops.LSTMBlockFusedCell( + cell_size, cell_clip=0, use_peephole=use_peephole) + fused_outputs_op, fused_state_op = fused_cell( + inputs, dtype=dtypes.float32) + + sess.run([variables.global_variables_initializer()]) + basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]]) + basic_grads = sess.run(gradients_impl.gradients(basic_outputs_op, inputs)) + xs = [w, b] + if use_peephole: + xs += [wci, wcf, wco] + basic_wgrads = sess.run(gradients_impl.gradients(basic_outputs_op, xs)) + + block_outputs = sess.run(block_outputs_op) + block_grads = sess.run(gradients_impl.gradients(block_outputs_op, inputs)) + block_wgrads = sess.run(gradients_impl.gradients(block_outputs_op, xs)) + + xs = [w_block, b_block] + if use_peephole: + xs += [wci_block, wcf_block, wco_block] + fused_outputs, fused_state = sess.run([fused_outputs_op, fused_state_op[0]]) + fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs)) + fused_wgrads = sess.run(gradients_impl.gradients(fused_outputs_op, xs)) + + return (basic_state, fused_state, basic_outputs, block_outputs, + fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads, + block_wgrads, fused_wgrads) + + class LSTMBlockCellTest(test.TestCase): def testNoneDimsWithDynamicRNN(self): @@ -225,164 +332,39 @@ class LSTMBlockCellTest(test.TestCase): def testLSTMBasicToBlock(self): with self.test_session(use_gpu=True) as sess: - batch_size = 2 - input_size = 3 - cell_size = 4 - sequence_length = 5 - - inputs = [] - for _ in range(sequence_length): - inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) - inputs.append(inp) - - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=19890212) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) - outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - - with variable_scope.variable_scope("block", initializer=initializer): - w = variable_scope.get_variable( - "w", - shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) - b = variable_scope.get_variable( - "b", - shape=[cell_size * 4], - dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - - _, _, _, _, _, _, outputs = block_lstm( - ops.convert_to_tensor( - sequence_length, dtype=dtypes.int64), - inputs, - w, - b, - cell_clip=0) - - sess.run([variables.global_variables_initializer()]) - block_outputs = sess.run(outputs) - block_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - block_wgrads = sess.run(gradients_impl.gradients(outputs, [w, b])) + (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, + basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, + fused_wgrads) = blocks_match( + sess, use_peephole=False) self.assertAllClose(basic_outputs, block_outputs) self.assertAllClose(basic_grads, block_grads) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) - - with variable_scope.variable_scope("fused", initializer=initializer): - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=False) - outputs, state = cell(inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) self.assertAllClose(basic_outputs, fused_outputs) self.assertAllClose(basic_state, fused_state) self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + for basic, fused in zip(block_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) def testLSTMBasicToBlockPeeping(self): with self.test_session(use_gpu=True) as sess: - batch_size = 2 - input_size = 3 - cell_size = 4 - sequence_length = 5 - - inputs = [] - for _ in range(sequence_length): - inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) - inputs.append(inp) - - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=19890212) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.LSTMCell( - cell_size, use_peepholes=True, state_is_tuple=True) - outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - - with variable_scope.variable_scope("block", initializer=initializer): - w = variable_scope.get_variable( - "w", - shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) - b = variable_scope.get_variable( - "b", - shape=[cell_size * 4], - dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - - wci = variable_scope.get_variable( - "wci", shape=[cell_size], dtype=dtypes.float32) - wcf = variable_scope.get_variable( - "wcf", shape=[cell_size], dtype=dtypes.float32) - wco = variable_scope.get_variable( - "wco", shape=[cell_size], dtype=dtypes.float32) - - _, _, _, _, _, _, outputs = block_lstm( - ops.convert_to_tensor( - sequence_length, dtype=dtypes.int64), - inputs, - w, - b, - wci=wci, - wcf=wcf, - wco=wco, - cell_clip=0, - use_peephole=True) - - sess.run([variables.global_variables_initializer()]) - block_outputs = sess.run(outputs) - block_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - block_wgrads = sess.run( - gradients_impl.gradients(outputs, [w, b, wci, wcf, wco])) + (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, + basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, + fused_wgrads) = blocks_match( + sess, use_peephole=True) self.assertAllClose(basic_outputs, block_outputs) self.assertAllClose(basic_grads, block_grads) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) - - with variable_scope.variable_scope("fused", initializer=initializer): - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=True) - outputs, state = cell(inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) self.assertAllClose(basic_outputs, fused_outputs) self.assertAllClose(basic_state, fused_state) self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + for basic, fused in zip(block_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) def testLSTMFusedSequenceLengths(self): """Verify proper support for sequence lengths in LSTMBlockFusedCell.""" @@ -401,45 +383,40 @@ class LSTMBlockCellTest(test.TestCase): initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=19890213) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) - outputs, state = rnn.static_rnn( - cell, inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - with variable_scope.variable_scope("fused", initializer=initializer): + with variable_scope.variable_scope( + "lstm_block_wrapper", initializer=initializer): + # magic naming so that the cells pick up these variables and resuse them + variable_scope.get_variable( + "kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + + variable_scope.get_variable( + "bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): cell = lstm_ops.LSTMBlockFusedCell( cell_size, cell_clip=0, use_peephole=False) - outputs, state = cell( - inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + fused_outputs_op, fused_state_op = cell( + inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - self.assertAllClose(basic_outputs, fused_outputs) - self.assertAllClose(basic_state, fused_state) - self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + cell_vars = [ + v for v in variables.trainable_variables() + if v.name.endswith("kernel") or v.name.endswith("bias") + ] # Verify that state propagation works if we turn our sequence into # tiny (single-time) subsequences, i.e. unfuse the cell + unfused_outputs_op = [] + state = None with variable_scope.variable_scope( - "unfused", initializer=initializer) as vs: - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=False) - outputs = [] - state = None + variable_scope.get_variable_scope(), reuse=True): for i, inp in enumerate(inputs): lengths = [int(i < l) for l in seq_lengths.eval()] output, state = cell( @@ -447,25 +424,136 @@ class LSTMBlockCellTest(test.TestCase): initial_state=state, dtype=dtypes.float32, sequence_length=lengths) - vs.reuse_variables() - outputs.append(output[0]) - outputs = array_ops.stack(outputs) - - sess.run([variables.global_variables_initializer()]) - unfused_outputs, unfused_state = sess.run([outputs, state[0]]) - unfused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - unfused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("unfused/") - ] - unfused_wgrads = sess.run( - gradients_impl.gradients(outputs, unfused_vars)) - - self.assertAllClose(basic_outputs, unfused_outputs) - self.assertAllClose(basic_state, unfused_state) - self.assertAllClose(basic_grads, unfused_grads) - for basic, unfused in zip(basic_wgrads, unfused_wgrads): - self.assertAllClose(basic, unfused, rtol=1e-2, atol=1e-2) + unfused_outputs_op.append(output[0]) + unfused_outputs_op = array_ops.stack(unfused_outputs_op) + + sess.run([variables.global_variables_initializer()]) + unfused_outputs, unfused_state = sess.run([unfused_outputs_op, state[0]]) + unfused_grads = sess.run( + gradients_impl.gradients(unfused_outputs_op, inputs)) + unfused_wgrads = sess.run( + gradients_impl.gradients(unfused_outputs_op, cell_vars)) + + fused_outputs, fused_state = sess.run( + [fused_outputs_op, fused_state_op[0]]) + fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs)) + fused_wgrads = sess.run( + gradients_impl.gradients(fused_outputs_op, cell_vars)) + + self.assertAllClose(fused_outputs, unfused_outputs) + self.assertAllClose(fused_state, unfused_state) + self.assertAllClose(fused_grads, unfused_grads) + for fused, unfused in zip(fused_wgrads, unfused_wgrads): + self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6) + +#### Benchmarking. + + +class BenchmarkLSTMBlock(test.Benchmark): + + def benchmarkLSTMBlockCellFpropWithDynamicRNN(self): + print("BlockLSTMCell forward propagation via dynamic_rnn().") + print("--------------------------------------------------------------") + print("LSTMBlockCell Seconds per inference.") + print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time") + iters = 10 + for config in benchmarking.dict_product({ + "batch_size": [1, 8, 13, 32, 67, 128], + "cell_size": [128, 250, 512, 650, 1024, 1350], + "time_steps": [40], + "use_gpu": [True, False] + }): + with ops.Graph().as_default(): + with benchmarking.device(use_gpu=config["use_gpu"]): + inputs = variable_scope.get_variable( + "x", + [config["time_steps"], config["batch_size"], config["cell_size"]]) + cell = lstm_ops.LSTMBlockCell(config["cell_size"]) + outputs = rnn.dynamic_rnn( + cell, inputs, time_major=True, dtype=dtypes.float32) + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + wall_time = benchmarking.seconds_per_run(outputs, sess, iters) + + # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable + # is set, this will produce a copy-paste-able CSV file. + print(",".join( + map(str, [ + config["batch_size"], config["cell_size"], config["cell_size"], + config["time_steps"], config["use_gpu"], wall_time + ]))) + benchmark_name_template = "_".join([ + "LSTMBlockCell_fprop", "BS%(batch_size)i", "CS%(cell_size)i", + "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + ]) + + self.report_benchmark( + name=benchmark_name_template % config, + iters=iters, + wall_time=wall_time, + extras=config) + + def benchmarkLSTMBlockCellBpropWithDynamicRNN(self): + print("BlockLSTMCell backward propagation via dynamic_rnn().") + print("--------------------------------------------------------------") + print("LSTMBlockCell Seconds per inference.") + print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time") + iters = 10 + for config in benchmarking.dict_product({ + "batch_size": [1, 8, 13, 32, 67, 128], + "cell_size": [128, 250, 512, 650, 1024, 1350], + "time_steps": [40], + "use_gpu": [True, False] + }): + with ops.Graph().as_default(): + with benchmarking.device(use_gpu=config["use_gpu"]): + time_steps = config["time_steps"] + batch_size = config["batch_size"] + cell_size = input_size = config["cell_size"] + inputs = variable_scope.get_variable( + "x", [time_steps, batch_size, cell_size], + trainable=False, + dtype=dtypes.float32) + with variable_scope.variable_scope( + "rnn", reuse=variable_scope.AUTO_REUSE): + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + cell = lstm_ops.LSTMBlockCell(cell_size) + outputs = rnn.dynamic_rnn( + cell, inputs, time_major=True, dtype=dtypes.float32) + grads = gradients_impl.gradients(outputs, [inputs, w, b]) + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + wall_time = benchmarking.seconds_per_run(grads, sess, iters) + + # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable + # is set, this will produce a copy-paste-able CSV file. + print(",".join( + map(str, [ + batch_size, cell_size, cell_size, time_steps, config["use_gpu"], + wall_time + ]))) + benchmark_name_template = "_".join([ + "LSTMBlockCell_bprop", "BS%(batch_size)i", "CS%(cell_size)i", + "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + ]) + + self.report_benchmark( + name=benchmark_name_template % config, + iters=iters, + wall_time=wall_time, + extras=config) if __name__ == "__main__": diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 352dae3acf75240c000f0339982a3652d9466200..df910a3423083972bdee42bec10733e37b8e5f96 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -116,8 +116,8 @@ def _lstm_block_cell(x, if cell_size is None: raise ValueError("cell_size from `cs_prev` should not be None.") wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) - wco = wci wcf = wci + wco = wci # pylint: disable=protected-access return gen_lstm_ops.lstm_block_cell( @@ -126,8 +126,8 @@ def _lstm_block_cell(x, h_prev=h_prev, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=forget_bias, cell_clip=cell_clip if cell_clip is not None else -1, @@ -201,8 +201,8 @@ def _block_lstm(seq_len_max, h_prev = zero_state if wci is None: wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) - wco = wci wcf = wci + wco = wci # pylint: disable=protected-access i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm( @@ -212,8 +212,8 @@ def _block_lstm(seq_len_max, h_prev=h_prev, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=forget_bias, cell_clip=cell_clip if cell_clip is not None else -1, @@ -233,7 +233,7 @@ _lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"] @ops.RegisterGradient("LSTMBlockCell") def _LSTMBlockCellGrad(op, *grad): """Gradient for LSTMBlockCell.""" - (x, cs_prev, h_prev, w, wci, wco, wcf, b) = op.inputs + (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs (i, cs, f, o, ci, co, _) = op.outputs (_, cs_grad, _, _, _, _, h_grad) = grad @@ -293,13 +293,13 @@ def _LSTMBlockCellGrad(op, *grad): @ops.RegisterGradient("BlockLSTM") def _BlockLSTMGrad(op, *grad): """Gradient for BlockLSTM.""" - seq_len_max, x, cs_prev, h_prev, w, wci, wco, wcf, b = op.inputs + seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs i, cs, f, o, ci, co, h = op.outputs cs_grad = grad[1] h_grad = grad[6] - (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad, + (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad, b_grad) = gen_lstm_ops.block_lstm_grad( seq_len_max, x, @@ -307,8 +307,8 @@ def _BlockLSTMGrad(op, *grad): h_prev, w, wci, - wco, wcf, + wco, b, i, cs, @@ -321,8 +321,10 @@ def _BlockLSTMGrad(op, *grad): h_grad, use_peephole=op.get_attr("use_peephole")) - return [None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, - wcf_grad, b_grad] + return [ + None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, + wco_grad, b_grad + ] class LSTMBlockCell(rnn_cell_impl.RNNCell): @@ -367,8 +369,8 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): "W": "kernel", "b": "bias", "wci": "w_i_diag", - "wco": "w_o_diag", "wcf": "w_f_diag", + "wco": "w_o_diag", "scope": "lstm_cell" } @@ -396,10 +398,10 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): initializer=init_ops.constant_initializer(0.0)) if self._use_peephole: wci = vs.get_variable(self._names["wci"], [self._num_units]) - wco = vs.get_variable(self._names["wco"], [self._num_units]) wcf = vs.get_variable(self._names["wcf"], [self._num_units]) + wco = vs.get_variable(self._names["wco"], [self._num_units]) else: - wci = wco = wcf = array_ops.zeros([self._num_units]) + wci = wcf = wco = array_ops.zeros([self._num_units]) (cs_prev, h_prev) = states_prev (_, cs, _, _, _, _, h) = _lstm_block_cell( x, @@ -408,8 +410,8 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): w, b, wci=wci, - wco=wco, wcf=wcf, + wco=wco, forget_bias=self._forget_bias, cell_clip=self._cell_clip, use_peephole=self._use_peephole) @@ -644,10 +646,10 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): dtype=dtype) if self._use_peephole: wci = vs.get_variable("w_i_diag", [self._num_units], dtype=dtype) - wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype) wcf = vs.get_variable("w_f_diag", [self._num_units], dtype=dtype) + wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype) else: - wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype) + wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype) if sequence_length is None: max_seq_len = math_ops.to_int64(time_len) @@ -661,8 +663,8 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): h_prev=initial_output, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=self._forget_bias, cell_clip=self._cell_clip, diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 7cb1e7f364143602e19a754d3c5300ee32dbf37c..6702a89d22283be691deb11339d3374bf8c4fd93 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -525,7 +525,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): self._state_tuple_type = collections.namedtuple( "GridLSTMStateTuple", state_names.strip(",")) self._state_size = self._state_tuple_type( - *([num_units, num_units] * self._total_blocks)) + *([num_units, num_units] * self._total_blocks)) else: self._state_tuple_type = None self._state_size = num_units * self._total_blocks * 2 @@ -2082,9 +2082,11 @@ def _conv(args, shape_length = len(shapes[0]) for shape in shapes: if len(shape) not in [3,4,5]: - raise ValueError("Conv Linear expects 3D, 4D or 5D arguments: %s" % str(shapes)) + raise ValueError("Conv Linear expects 3D, 4D " + "or 5D arguments: %s" % str(shapes)) if len(shape) != len(shapes[0]): - raise ValueError("Conv Linear expects all args to be of same Dimension: %s" % str(shapes)) + raise ValueError("Conv Linear expects all args " + "to be of same Dimension: %s" % str(shapes)) else: total_arg_size_depth += shape[-1] dtype = [a.dtype for a in args][0] @@ -2102,7 +2104,7 @@ def _conv(args, # Now the computation. kernel = vs.get_variable( - "kernel", + "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype) if len(args) == 1: diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index aab0f3f4947388741765b268094b4136d356a457..64973ccccdc962757a727d7183bd70e94edcfd1b 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -49,40 +49,46 @@ class GatherTreeOp : public OpKernel { const Device& device = ctx->eigen_device(); const Tensor& step_ids = ctx->input(0); const Tensor& parent_ids = ctx->input(1); - const Tensor& sequence_length = ctx->input(2); + const Tensor& max_sequence_lengths = ctx->input(2); + const Tensor& end_token = ctx->input(3); const TensorShape& step_ids_shape = step_ids.shape(); OP_REQUIRES( ctx, step_ids_shape.dims() == 3, errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ", step_ids_shape.DebugString())); - OP_REQUIRES( - ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()), - errors::InvalidArgument("sequence_length must be a matrix, saw shape: ", - sequence_length.shape().DebugString())); - OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1), - errors::InvalidArgument( - "Inconsistent batch sizes: sequence_length.shape[0] (", - sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (", - step_ids_shape.dim_size(1), ")")); - OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2), + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()), errors::InvalidArgument( - "Inconsistent batch sizes: sequence_length.shape[1] (", - sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (", - step_ids_shape.dim_size(2), ")")); + "max_sequence_lengths must be a vector, saw shape: ", + max_sequence_lengths.shape().DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(end_token.shape()), + errors::InvalidArgument("end_token must be a scalar, saw shape: ", + end_token.shape().DebugString())); OP_REQUIRES( ctx, step_ids_shape == parent_ids.shape(), errors::InvalidArgument( "step_ids.shape must match parent_ids.shape. but shapes are: ", step_ids_shape.DebugString(), " and ", parent_ids.shape().DebugString())); + OP_REQUIRES( + ctx, + step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0), + errors::InvalidArgument("batch size dimensions step_ids.shape[1] and " + "max_seqeuence_lengths.shape[0] must match. " + "but shapes are: ", + step_ids_shape.DebugString(), " and ", + max_sequence_lengths.shape().DebugString())); Tensor* beams; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); typename TTypes::ConstTensor step_ids_t = step_ids.tensor(); typename TTypes::ConstTensor parent_ids_t = parent_ids.tensor(); - typename TTypes::ConstMatrix seq_len_t = sequence_length.matrix(); + typename TTypes::ConstVec max_seq_lens_t = + max_sequence_lengths.vec(); + typename TTypes::ConstScalar end_token_t = end_token.scalar(); typename TTypes::Tensor beams_t = beams->tensor(); + const T end_token_value = end_token_t(); functor::GatherTree()(ctx, device, step_ids_t, parent_ids_t, - seq_len_t, beams_t); + max_seq_lens_t, end_token_value, beams_t); } }; @@ -99,27 +105,29 @@ namespace functor { template <> struct GatherTree { void operator()(OpKernelContext* ctx, const CPUDevice& d, - typename TTypes::ConstTensor step_ids, - typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstMatrix sequence_length, - typename TTypes::Tensor beams) { - const int64 max_time = parent_ids.dimension(0); - const int64 batch_size = parent_ids.dimension(1); - const int64 beam_width = parent_ids.dimension(2); - beams.setConstant(-1); - - auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) { + TTypes::ConstTensor step_ids, + TTypes::ConstTensor parent_ids, + TTypes::ConstVec max_sequence_lengths, + const int32 end_token, TTypes::Tensor beams) { + const int32 max_time = parent_ids.dimension(0); + const int32 batch_size = parent_ids.dimension(1); + const int32 beam_width = parent_ids.dimension(2); + beams.setConstant(end_token); + + auto DoWork = [&, ctx, end_token](int start_batch_beam, + int limit_batch_beam) { for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - int32 seq_len_b = sequence_length(batch, beam); - if (seq_len_b <= 0) { + const int32 max_seq_len_b = + Eigen::numext::mini(max_time, max_sequence_lengths(batch)); + if (max_seq_len_b <= 0) { continue; } - beams(seq_len_b - 1, batch, beam) = - step_ids(seq_len_b - 1, batch, beam); - int32 parent = parent_ids(seq_len_b - 1, batch, beam); - for (int32 level = seq_len_b - 2; level >= 0; --level) { + beams(max_seq_len_b - 1, batch, beam) = + step_ids(max_seq_len_b - 1, batch, beam); + int32 parent = parent_ids(max_seq_len_b - 1, batch, beam); + for (int32 level = max_seq_len_b - 2; level >= 0; --level) { if (parent < 0 || parent > beam_width) { ctx->SetStatus( errors::InvalidArgument("Saw invalid parent id ", parent, @@ -130,6 +138,17 @@ struct GatherTree { beams(level, batch, beam) = step_ids(level, batch, parent); parent = parent_ids(level, batch, parent); } + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + if (finished) { + beams(time, batch, beam) = end_token; + } else if (beams(time, batch, beam) == end_token) { + finished = true; + } + } } }; // Guesstimate of cost; ~5 lookup/store/compare per inner beam @@ -137,7 +156,7 @@ struct GatherTree { const int64 batch_beam_cost = Eigen::TensorOpCost::DivCost() + 6 * Eigen::TensorOpCost::AddCost() + - max_time * (5 * Eigen::TensorOpCost::AddCost()); + 2 * max_time * (5 * Eigen::TensorOpCost::AddCost()); auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); Shard(worker_threads.num_threads, worker_threads.workers, batch_size * beam_width, batch_beam_cost, DoWork); @@ -148,24 +167,26 @@ struct GatherTree { #if GOOGLE_CUDA namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void GatherTree::operator()( \ - OpKernelContext* ctx, const GPUDevice& d, \ - typename TTypes::ConstTensor step_ids, \ - typename TTypes::ConstTensor parent_ids, \ - typename TTypes::ConstMatrix sequence_length, \ - typename TTypes::Tensor beams); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void GatherTree::operator()( \ + OpKernelContext* ctx, const GPUDevice& d, \ + typename TTypes::ConstTensor step_ids, \ + typename TTypes::ConstTensor parent_ids, \ + TTypes::ConstVec max_sequence_lengths, const T end_token, \ + typename TTypes::Tensor beams); \ extern template struct GatherTree; DECLARE_GPU_SPEC(int32); #undef DECLARE_GPU_SPEC } // end namespace functor -#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("GatherTree").Device(DEVICE_GPU).TypeConstraint("T"), \ - GatherTreeOp); +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("GatherTree") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("end_token"), \ + GatherTreeOp); REGISTER_GPU_KERNEL(int32); #undef REGISTER_GPU_KERNEL diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h index 124d07264e75ac4ce7739dd3291abdabbb40a58f..693b02dc437afdf14c38e4224c5469bb3e569540 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h @@ -31,8 +31,8 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const Device& d, typename TTypes::ConstTensor step_ids, typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstMatrix sequence_length, - typename TTypes::Tensor beams); + TTypes::ConstVec max_sequence_lengths, + const T end_token, typename TTypes::Tensor beams); }; } // namespace functor diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index ee68b55d20214c207597750e083a63e94ebdc0a0..bc28d492fe1a25afe0d0783539aa9e759e7b703f 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -29,30 +29,50 @@ template __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, const int32 beam_width, const T* step_ids, const T* parent_ids, - const T* sequence_length, T* beams) { + const int32* max_sequence_lengths, + const T end_token, T* beams) { CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam); - if (seq_len_b <= 0) continue; + const int32 max_seq_len_b = + Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch)); + if (max_seq_len_b <= 0) { + continue; + } #define GET_IX(time_ix, beam_ix) \ (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix)) - const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam); + const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam); beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix); int32 parent = ldg(parent_ids + initial_beam_ix); - for (int32 level = seq_len_b - 2; level >= 0; --level) { + bool found_bad = false; + for (int32 level = max_seq_len_b - 2; level >= 0; --level) { const int32 level_beam_ix = GET_IX(level, beam); const int32 level_parent_ix = GET_IX(level, parent); if (parent < 0 || parent > beam_width) { beams[level_beam_ix] = -1; parent = -1; + found_bad = true; } else { beams[level_beam_ix] = ldg(step_ids + level_parent_ix); parent = ldg(parent_ids + level_parent_ix); } } + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). + if (!found_bad) { + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + const int32 level_beam_ix = GET_IX(time, beam); + if (finished) { + beams[level_beam_ix] = end_token; + } else if (beams[level_beam_ix] == end_token) { + finished = true; + } + } + } #undef GET_IX } } @@ -62,20 +82,23 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstTensor step_ids, typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstMatrix sequence_length, - typename TTypes::Tensor beams) { + TTypes::ConstVec max_sequence_length, + const T end_token, typename TTypes::Tensor beams) { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - // First kernel launch to zero things out - beams.device(d) = beams.constant(T(-1)); + // First kernel launch to "zero" things out + beams.device(d) = beams.constant(end_token); CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); // clang-format off GatherTreeOpKernel <<>>( batch_size, max_time, beam_width, - step_ids.data(), parent_ids.data(), sequence_length.data(), + step_ids.data(), + parent_ids.data(), + max_sequence_length.data(), + end_token, beams.data()); // clang-format on } diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc index 6c445cd4606381ed56d91000bc5e42d874ca0c5c..71539b6f592f0c8e53c4bb3801d1e35f34814966 100644 --- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -25,27 +25,27 @@ using shape_inference::ShapeHandle; REGISTER_OP("GatherTree") .Input("step_ids: T") .Input("parent_ids: T") - .Input("sequence_length: T") + .Input("max_sequence_lengths: int32") + .Input("end_token: T") .Output("beams: T") .Attr("T: {int32}") .SetShapeFn([](InferenceContext* c) { - ShapeHandle step_ids, parent_ids, sequence_length; + ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token; // step_ids, parent_ids, and output are all shaped: // [max_time, batch_size, beam_width]. - // sequence_length is shaped [batch_size, beam_width]. + // max_sequence_length is shaped [batch_size] and end_token is a scalar. TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids)); TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length)); - - DimensionHandle batch_size = c->Dim(step_ids, 1); - DimensionHandle beam_width = c->Dim(step_ids, 2); - + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token)); TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids)); + DimensionHandle batch_size = c->Dim(step_ids, 1); TF_RETURN_IF_ERROR( - c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size)); - TF_RETURN_IF_ERROR( - c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width)); + c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size)); + ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size); + TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids, + &step_ids_prefix)); c->set_output(0, step_ids); return tensorflow::Status::OK(); @@ -53,15 +53,19 @@ REGISTER_OP("GatherTree") .Doc(R"doc( Calculates the full beams from the per-step ids and parent beam ids. -This op implements the following mathematical equations: +On CPU, if an out of bound parent id is found, an error is returned. +On GPU, if an out of bound parent id is found, a -1 is stored in the +corresponding output value and the execution for that beam returns early. + +For a given beam, past the time step containing the first decoded `end_token` +all values are filled in with `end_token`. -```python -TODO(ebrevdo): fill in -``` +TODO(ebrevdo): fill in the remainder of this docstring. step_ids: `[max_time, batch_size, beam_width]`. parent_ids: `[max_time, batch_size, beam_width]`. -sequence_length: `[batch_size, beam_width]`. +max_sequence_lengths: `[batch_size]`. +end_token: `[]`. beams: `[max_time, batch_size, beam_width]`. )doc"); diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 2caeb9eb614382c815984391df87a70516f519b2..d2beac5f31460ec1c0d978a9f6fcd0e0f09cb9b4 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -54,15 +54,18 @@ class TestGatherTree(test.TestCase): [[0, 0, 0], [1, 2, 0], [2, 1, 1]]], dtype=np.int32).transpose([1, 0, 2]) - # sequence_lengths is shaped (batch_size = 2, beam_width = 3) - sequence_lengths = [[3, 3, 3], [3, 3, 3]] + # sequence_lengths is shaped (batch_size = 3) + max_sequence_lengths = [3, 3] expected_result = np.array( [[[2, 2, 2], [6, 5, 6], [7, 8, 9]], [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2]) res = beam_search_ops.gather_tree( - predicted_ids, parent_ids, sequence_lengths) + predicted_ids, + parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=11) with self.test_session() as sess: res_ = sess.run(res) @@ -80,8 +83,7 @@ class TestEosMasking(test.TestCase): ]) eos_token = 0 - previously_finished = constant_op.constant( - [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32) + previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool) masked = beam_search_decoder._mask_probs(probs, eos_token, previously_finished) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index 50cccf392fdac75f551b180987aff0b31da0893e..277c5b6ef76bce8d59e47cf0026c6e2b1d5cf1e2 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function # pylint: enable=unused-import +import itertools + import numpy as np from tensorflow.contrib.seq2seq.python.ops import beam_search_ops @@ -34,31 +36,37 @@ class GatherTreeTest(test.TestCase): def testGatherTreeOne(self): # (max_time = 4, batch_size = 1, beams = 3) + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] - expected_result = _transpose_batch_time( - [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + max_sequence_lengths = [3] + expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testBadParentValuesOnCPU(self): # (batch_size = 1, max_time = 4, beams = 3) # bad parent in beam 1 time 1 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] + max_sequence_lengths = [3] with ops.device("/cpu:0"): beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) with self.test_session(): with self.assertRaisesOpError( r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): @@ -71,82 +79,63 @@ class GatherTreeTest(test.TestCase): return # (max_time = 4, batch_size = 1, beams = 3) # bad parent in beam 1 time 1; appears as a negative index at time 0 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] - expected_result = _transpose_batch_time( - [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + max_sequence_lengths = [3] + expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) with ops.device("/device:GPU:0"): beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testGatherTreeBatch(self): - # sequence_length is [batch_size, beam_width] = [4, 5] - sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5] + batch_size = 10 + beam_width = 15 + max_time = 8 + max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0] + end_token = 5 with self.test_session(use_gpu=True): - # (max_time = 4, batch_size = 4, beam_width = 5) - step_ids = _transpose_batch_time( - [[[3, 4, 0, 4, 0], - [4, 2, 0, 3, 1], - [1, 1, 3, 2, 2], - [3, 1, 2, 3, 4]], - [[3, 4, 0, 4, 0], - [4, 2, 0, 3, 1], - [1, 1, 3, 2, 2], - [3, 1, 2, 3, 4]], - [[1, 2, 3, 4, 2], - [2, 1, 1, 3, 2], - [3, 0, 1, 0, 0], - [3, 4, 0, 2, 4]], - [[0, 2, 2, 3, 1], - [3, 2, 2, 2, 3], - [3, 4, 3, 0, 3], - [1, 2, 2, 2, 4]]]) - parent_ids = _transpose_batch_time( - [[[4, 2, 4, 3, 4], - [3, 4, 0, 2, 0], - [3, 1, 3, 2, 2], - [0, 2, 1, 4, 2]], - [[4, 2, 4, 3, 4], - [3, 4, 0, 2, 0], - [3, 1, 3, 2, 2], - [0, 2, 1, 4, 2]], - [[3, 0, 0, 4, 0], - [1, 2, 4, 2, 2], - [4, 4, 0, 3, 0], - [2, 4, 4, 3, 0]], - [[3, 1, 4, 1, 3], - [3, 2, 4, 0, 4], - [1, 0, 1, 4, 2], - [0, 3, 2, 0, 1]]]) - expected_beams = _transpose_batch_time( - [[[-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[3, 4, 0, 4, 0], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[2, 3, 2, 3, 3], - [2, 1, 1, 3, 2], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[2, 3, 2, 1, 1], - [2, 3, 2, 3, 2], - [3, 4, 3, 0, 3], - [-1, -1, -1, -1, -1]]]) + step_ids = np.random.randint( + 0, high=end_token + 1, size=(max_time, batch_size, beam_width)) + parent_ids = np.random.randint( + 0, high=beam_width - 1, size=(max_time, batch_size, beam_width)) beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) - self.assertAllEqual(expected_beams, beams.eval()) + step_ids=step_ids.astype(np.int32), + parent_ids=parent_ids.astype(np.int32), + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) + + self.assertEqual((max_time, batch_size, beam_width), beams.shape) + beams_value = beams.eval() + for b in range(batch_size): + # Past max_sequence_lengths[b], we emit all end tokens. + b_value = beams_value[max_sequence_lengths[b]:, b, :] + self.assertAllClose(b_value, end_token * np.ones_like(b_value)) + for batch, beam in itertools.product( + range(batch_size), range(beam_width)): + v = np.squeeze(beams_value[:, batch, beam]) + if end_token in v: + found_bad = np.where(v == -1)[0] + self.assertEqual(0, len(found_bad)) + found = np.where(v == end_token)[0] + found = found[0] # First occurrence of end_token. + # If an end_token is found, everything before it should be a + # valid id and everything after it should be -1. + if found > 0: + self.assertAllEqual( + v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool)) + self.assertAllClose(v[found + 1:], + end_token * np.ones_like(v[found + 1:])) if __name__ == "__main__": diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index e22912ac5c9e378587d092ae2bed56929fe2a8e7..5be0c92243da10af438be97fab982515266be1de 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -20,9 +20,10 @@ from __future__ import print_function import collections +import numpy as np + from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -252,6 +253,20 @@ class BeamSearchDecoder(decoder.Decoder): output_shape_with_unknown_batch) return nest.map_structure(lambda s: s[1:], layer_output_shape) + @property + def tracks_own_finished(self): + """The BeamSearchDecoder shuffles its beams and their finished state. + + For this reason, it conflicts with the `dynamic_decode` function's + tracking of finished states. Setting this property to true avoids + early stopping of decoding due to mismanagement of the finished state + in `dynamic_decode`. + + Returns: + `True`. + """ + return True + @property def output_size(self): # Return the cell output and the id @@ -302,15 +317,23 @@ class BeamSearchDecoder(decoder.Decoder): output. sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`. The sequence lengths determined for each beam during decode. + **NOTE** These are ignored; the updated sequence lengths are stored in + `final_state.lengths`. Returns: - outputs: An instance of FinalBeamSearchDecoderOutput where the + outputs: An instance of `FinalBeamSearchDecoderOutput` where the predicted_ids are the result of calling _gather_tree. - final_state: The same input instance of BeamSearchDecoderState. + final_state: The same input instance of `BeamSearchDecoderState`. """ + del sequence_lengths + # Get max_sequence_length across all beams for each batch. + max_sequence_lengths = math_ops.to_int32( + math_ops.reduce_max(final_state.lengths, axis=1)) predicted_ids = beam_search_ops.gather_tree( - outputs.predicted_ids, outputs.parent_ids, - sequence_length=sequence_lengths) + outputs.predicted_ids, + outputs.parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=self._end_token) outputs = FinalBeamSearchDecoderOutput( beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state @@ -390,17 +413,17 @@ class BeamSearchDecoder(decoder.Decoder): We do this so that we can use nest and not run into problems with shapes. Args: - t: Tensor of dimension [batch_size*beam_width, s] - s: Tensor, Python int, or TensorShape. + t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`. + s: `Tensor`, Python int, or `TensorShape`. Returns: - Either a reshaped version of t with dimension - [batch_size, beam_width, s] if t's first dimension is of size - batch_size*beam_width or t if not. + If `t` is a matrix or higher order tensor, then the return value is + `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is + returned unchanged. Raises: - TypeError: If t is an instance of TensorArray. - ValueError: If the rank of t is not statically known. + TypeError: If `t` is an instance of `TensorArray`. + ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 1: @@ -411,19 +434,19 @@ class BeamSearchDecoder(decoder.Decoder): def _maybe_merge_batch_beams(self, t, s): """Splits the tensor from a batch by beams into a batch of beams. - More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We - reshape this into [batch_size, beam_width, s] + More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`, + then we reshape it to `[batch_size, beam_width] + s`. Args: - t: Tensor of dimension [batch_size*beam_width, s] - s: Tensor, Python int, or TensorShape. + t: `Tensor` of dimension `[batch_size * beam_width] + s`. + s: `Tensor`, Python int, or `TensorShape`. Returns: - A reshaped version of t with dimension [batch_size, beam_width, s]. + A reshaped version of t with shape `[batch_size, beam_width] + s`. Raises: - TypeError: If t is an instance of TensorArray. - ValueError: If the rank of t is not statically known. + TypeError: If `t` is an instance of `TensorArray`. + ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 2: @@ -521,14 +544,12 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # Calculate the continuation lengths by adding to all continuing beams. vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1] lengths_to_add = array_ops.one_hot( - indices=array_ops.tile( - array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]), + indices=array_ops.fill([batch_size, beam_width], end_token), depth=vocab_size, - on_value=constant_op.constant(0, dtype=dtypes.int64), - off_value=constant_op.constant(1, dtype=dtypes.int64), + on_value=np.int64(0), off_value=np.int64(1), dtype=dtypes.int64) - add_mask = (1 - math_ops.to_int64(previously_finished)) - lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add + add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished)) + lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) @@ -589,12 +610,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, name="next_beam_finished") # Calculate the length of the next predictions. - # 1. Finished beams remain unchanged - # 2. Beams that are now finished (EOS predicted) remain unchanged - # 3. Beams that are not yet finished have their length increased by 1 - lengths_to_add = math_ops.to_int64( - math_ops.not_equal(next_word_ids, end_token)) - lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add + # 1. Finished beams remain unchanged. + # 2. Beams that are now finished (EOS predicted) have their length + # increased by 1. + # 3. Beams that are not yet finished have their length increased by 1. + lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished)) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, @@ -652,13 +672,20 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight): def _length_penalty(sequence_lengths, penalty_factor): """Calculates the length penalty. See https://arxiv.org/abs/1609.08144. + Returns the length penalty tensor: + ``` + [(5+sequence_lengths)/6]**penalty_factor + ``` + where all operations are performed element-wise. + Args: - sequence_lengths: The sequence length of all hypotheses, a tensor - of shape [beam_size, vocab_size]. + sequence_lengths: `Tensor`, the sequence lengths of each hypotheses. penalty_factor: A scalar that weights the length penalty. Returns: - The length penalty factor, a tensor fo shape [beam_size]. + If the penalty is `0`, returns the scalar `1.0`. Otherwise returns + the length penalty factor, a tensor with the same shape as + `sequence_lengths`. """ penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor") penalty_factor.set_shape(()) # penalty should be a scalar. @@ -680,8 +707,7 @@ def _mask_probs(probs, eos_token, finished): eos_token: An int32 id corresponding to the EOS token to allocate probability to. finished: A boolean tensor of shape `[batch_size, beam_width]` that - specifies which - elements in the beam are finished already. + specifies which elements in the beam are finished already. Returns: A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished @@ -689,10 +715,6 @@ def _mask_probs(probs, eos_token, finished): probability on the EOS token. """ vocab_size = array_ops.shape(probs)[2] - finished_mask = array_ops.expand_dims( - math_ops.to_float(1. - math_ops.to_float(finished)), 2) - # These examples are not finished and we leave them - non_finished_examples = finished_mask * probs # All finished examples are replaced with a vector that has all # probability on EOS finished_row = array_ops.one_hot( @@ -701,8 +723,13 @@ def _mask_probs(probs, eos_token, finished): dtype=probs.dtype, on_value=0., off_value=probs.dtype.min) - finished_examples = (1. - finished_mask) * finished_row - return finished_examples + non_finished_examples + finished_probs = array_ops.tile( + array_ops.reshape(finished_row, [1, 1, -1]), + array_ops.concat([array_ops.shape(finished), [1]], 0)) + finished_mask = array_ops.tile( + array_ops.expand_dims(finished, 2), [1, 1, vocab_size]) + + return array_ops.where(finished_mask, finished_probs, probs) def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index fbe53fc60ada85c40970870c6d0bdb93d17ea6d4..f14974b9d5ca8cbcfd9f91086ca0a90ceff48f43 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -100,16 +100,36 @@ class Decoder(object): Returns: `(outputs, next_state, next_inputs, finished)`: `outputs` is an object - containing the decoder output, `next_state` is a (structure of) state tensors - and TensorArrays, `next_inputs` is the tensor that should be used as input for - the next step, `finished` is a boolean tensor telling whether the sequence - is complete, for each sequence in the batch. + containing the decoder output, `next_state` is a (structure of) state + tensors and TensorArrays, `next_inputs` is the tensor that should be used + as input for the next step, `finished` is a boolean tensor telling whether + the sequence is complete, for each sequence in the batch. """ raise NotImplementedError def finalize(self, outputs, final_state, sequence_lengths): raise NotImplementedError + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `dynamic_decode` function keeps track + of which batch entries are already finished, and performs a logical OR to + insert new batches to the finished set. + + Some decoders, however, shuffle batches / beams between time steps and + `dynamic_decode` will mix up the finished state across these entries because + it does not track the reshuffle across time steps. In this case, it is + up to the decoder to declare that it will keep track of its own finished + state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + def _create_zero_outputs(size, dtype, batch_size): """Create a zero outputs Tensor structure.""" @@ -232,7 +252,10 @@ def dynamic_decode(decoder, """ (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state) - next_finished = math_ops.logical_or(decoder_finished, finished) + if decoder.tracks_own_finished: + next_finished = decoder_finished + else: + next_finished = math_ops.logical_or(decoder_finished, finished) if maximum_iterations is not None: next_finished = math_ops.logical_or( next_finished, time + 1 >= maximum_iterations) diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 0a8ec7710976843d41769448483abc3ee236ee40..2204b684ac993cd82e69b3fd74801bff610b5fd4 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -5,7 +5,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test") # @unused py_library( name = "signal_py", diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 094568389cfdd2fd83b939cb9242694391f3844b..0544404e9e252cca6d3650b805b91be25d705eea 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -207,6 +207,76 @@ class Tensor(ItemHandler): return tensor +class LookupTensor(Tensor): + """An ItemHandler that returns a parsed Tensor, the result of a lookup.""" + + def __init__(self, + tensor_key, + table, + shape_keys=None, + shape=None, + default_value=''): + """Initializes the LookupTensor handler. + + See Tensor. Simply calls a vocabulary (most often, a label mapping) lookup. + + Args: + tensor_key: the name of the `TFExample` feature to read the tensor from. + table: A tf.lookup table. + shape_keys: Optional name or list of names of the TF-Example feature in + which the tensor shape is stored. If a list, then each corresponds to + one dimension of the shape. + shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is + reshaped accordingly. + default_value: The value used when the `tensor_key` is not found in a + particular `TFExample`. + + Raises: + ValueError: if both `shape_keys` and `shape` are specified. + """ + self._table = table + super(LookupTensor, self).__init__(tensor_key, shape_keys, shape, + default_value) + + def tensors_to_item(self, keys_to_tensors): + unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors) + return self._table.lookup(unmapped_tensor) + + +class BackupHandler(ItemHandler): + """An ItemHandler that tries two ItemHandlers in order.""" + + def __init__(self, handler, backup): + """Initializes the BackupHandler handler. + + If the first Handler's tensors_to_item returns a Tensor with no elements, + the second Handler is used. + + Args: + handler: The primary ItemHandler. + backup: The backup ItemHandler. + + Raises: + ValueError: if either is not an ItemHandler. + """ + if not isinstance(handler, ItemHandler): + raise ValueError('Primary handler is of type %s instead of ItemHandler' + % type(handler)) + if not isinstance(backup, ItemHandler): + raise ValueError('Backup handler is of type %s instead of ItemHandler' + % type(backup)) + self._handler = handler + self._backup = backup + super(BackupHandler, self).__init__(handler.keys + backup.keys) + + def tensors_to_item(self, keys_to_tensors): + item = self._handler.tensors_to_item(keys_to_tensors) + return control_flow_ops.cond( + pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0), + true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors), + false_fn=lambda: item) + + class SparseTensor(ItemHandler): """An ItemHandler for SparseTensors.""" diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 99f63134870c38fbd5c0f72b71ab55db4698b824..d783d4fef42bb2acffe7eb8b155c5efaed7896d9 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import image_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test @@ -751,10 +752,14 @@ class TFExampleDecoderTest(test.TestCase): serialized_example = array_ops.reshape(serialized_example, shape=[]) keys_to_features = { - 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature([], dtypes.float32, allow_missing=True), - 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature([], dtypes.float32, allow_missing=True), - 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature([], dtypes.float32, allow_missing=True), - 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature([], dtypes.float32, allow_missing=True), + 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), } items_to_handlers = { @@ -807,6 +812,87 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllEqual(np.squeeze(output_image[0, :, :, :]), image) self.assertAllEqual(np.squeeze(output_image[1, :, :, :]), image) + def testDecodeExampleWithLookup(self): + + example = example_pb2.Example(features=feature_pb2.Features(feature={ + 'image/object/class/text': self._BytesFeature( + np.array(['cat', 'dog', 'guinea pig'])), + })) + serialized_example = example.SerializeToString() + # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2 + table = lookup_ops.index_table_from_tensor( + constant_op.constant(['dog', 'guinea pig', 'cat'])) + + with self.test_session() as sess: + sess.run(lookup_ops.tables_initializer()) + + serialized_example = array_ops.reshape(serialized_example, shape=[]) + + keys_to_features = { + 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string), + } + + items_to_handlers = { + 'labels': + tfexample_decoder.LookupTensor('image/object/class/text', table), + } + + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + obtained_class_ids = decoder.decode(serialized_example)[0].eval() + + self.assertAllClose([2, 0, 1], obtained_class_ids) + + def testDecodeExampleWithBackupHandlerLookup(self): + + example1 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/text': + self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])), + 'image/object/class/label': + self._EncodedInt64Feature(np.array([42, 10, 900])) + })) + example2 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/text': + self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])), + })) + example3 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/label': + self._EncodedInt64Feature(np.array([42, 10, 901])) + })) + # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2 + table = lookup_ops.index_table_from_tensor( + constant_op.constant(['dog', 'guinea pig', 'cat'])) + keys_to_features = { + 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string), + 'image/object/class/label': parsing_ops.VarLenFeature(dtypes.int64), + } + backup_handler = tfexample_decoder.BackupHandler( + handler=tfexample_decoder.Tensor('image/object/class/label'), + backup=tfexample_decoder.LookupTensor('image/object/class/text', table)) + items_to_handlers = { + 'labels': backup_handler, + } + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + obtained_class_ids_each_example = [] + with self.test_session() as sess: + sess.run(lookup_ops.tables_initializer()) + for example in [example1, example2, example3]: + serialized_example = array_ops.reshape( + example.SerializeToString(), shape=[]) + obtained_class_ids_each_example.append( + decoder.decode(serialized_example)[0].eval()) + + self.assertAllClose([42, 10, 900], obtained_class_ids_each_example[0]) + self.assertAllClose([2, 0, 1], obtained_class_ids_each_example[1]) + self.assertAllClose([42, 10, 901], obtained_class_ids_each_example[2]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 5ee014a1f11a6b0d11857d209f27b134b737275d..def00b76184ba4e1fc630cd83d8e055448100562 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -552,7 +552,8 @@ def train(train_op, sync_optimizer=None, session_config=None, session_wrapper=None, - trace_every_n_steps=None): + trace_every_n_steps=None, + ignore_live_threads=False): """Runs a training loop using a TensorFlow supervisor. When the sync_optimizer is supplied, gradient updates are applied @@ -615,6 +616,9 @@ def train(train_op, trace_every_n_steps: produce and save a `Timeline` in Chrome trace format and add it to the summaries every `trace_every_n_steps`. If None, no trace information will be produced or saved. + ignore_live_threads: If `True` ignores threads that remain running after + a grace period when stopping the supervisor, instead of raising a + RuntimeError. Returns: the value of the loss function after training. @@ -772,7 +776,10 @@ def train(train_op, if logdir and sv.is_chief: logging.info('Finished training! Saving model to disk.') sv.saver.save(sess, sv.save_path, global_step=sv.global_step) - sv.stop(threads, close_summary_writer=True) + sv.stop( + threads, + close_summary_writer=True, + ignore_live_threads=ignore_live_threads) except errors.AbortedError: # Always re-run on AbortedError as it indicates a restart of one of the diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py index 9a36bdc2f9558220fa6cc47d5bb95d6e49a480f7..cd4d46aa07bfa92b8243f2f168fd1e4682ad70e2 100644 --- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py +++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import stateless +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -79,6 +80,21 @@ class StatelessOpsTest(test.TestCase): for s1, v1 in values: self.assertEqual(s0 == s1, np.all(v0 == v1)) + def testShapeType(self): + with self.test_session(use_gpu=True): + for shape_dtype in [dtypes.int32, dtypes.int64]: + seed_t = array_ops.placeholder(dtypes.int64, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for stateless_op, _ in CASES: + for shape in (), (3,), (2, 5): + pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype), + seed=seed_t) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) + for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index d09ad48e10a0dfe37860d302567f6cc241135422..8cb5c3f3818a0a5929817efb17fd7c2784fb9ea6 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -25,6 +25,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", + ":summary_test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:lib", @@ -43,15 +44,25 @@ py_library( deps = [ ":gen_summary_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:layers_base", "//tensorflow/python:summary_op_util", "//tensorflow/python:training", "//tensorflow/python/eager:context", ], ) +py_library( + name = "summary", + srcs = ["summary.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":summary_ops", + ], +) + filegroup( name = "all_files", srcs = glob( @@ -63,3 +74,17 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# NOTE: target cannot be testonly because it needs to be in the pip +# package. Sigh. +py_library( + name = "summary_test_util", + srcs = ["summary_test_util.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + ], +) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..89031caadc1206461aab75dd9496fd764a367d37 --- /dev/null +++ b/tensorflow/contrib/summary/summary.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Contrib summary package. + +The operations in this package are safe to use with eager execution turned or on +off. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.summary.summary_ops import all_summary_ops +from tensorflow.contrib.summary.summary_ops import always_record_summaries +from tensorflow.contrib.summary.summary_ops import audio +from tensorflow.contrib.summary.summary_ops import create_summary_file_writer +from tensorflow.contrib.summary.summary_ops import generic +from tensorflow.contrib.summary.summary_ops import histogram +from tensorflow.contrib.summary.summary_ops import image +from tensorflow.contrib.summary.summary_ops import never_record_summaries +from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps +from tensorflow.contrib.summary.summary_ops import scalar +from tensorflow.contrib.summary.summary_ops import should_record_summaries +from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index c8d0c14e1951a7c29eed096d2a2e9849c4326245..9c71bf7740c91dbb522464dc4abbf0a6ad31aeb2 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -24,21 +24,28 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import summary_op_util from tensorflow.python.training import training_util - +from tensorflow.python.util import tf_contextlib # Name for a collection which is expected to have at most a single boolean # Tensor. If this tensor is True the summary ops will record summaries. _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" +_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2" +_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2" + def should_record_summaries(): """Returns boolean Tensor which is true if summaries should be recorded.""" should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME) if not should_record_collection: - return constant_op.constant(False) + return False if len(should_record_collection) != 1: raise ValueError( "More than one tensor specified for whether summaries " @@ -47,22 +54,63 @@ def should_record_summaries(): # TODO(apassos) consider how to handle local step here. +@tf_contextlib.contextmanager def record_summaries_every_n_global_steps(n): """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [training_util.get_global_step() % n == 0] + old = collection_ref[:] + with ops.device("cpu:0"): + collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)] + yield + collection_ref[:] = old +@tf_contextlib.contextmanager def always_record_summaries(): """Sets the should_record_summaries Tensor to always true.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [constant_op.constant(True)] + old = collection_ref[:] + collection_ref[:] = [True] + yield + collection_ref[:] = old +@tf_contextlib.contextmanager def never_record_summaries(): """Sets the should_record_summaries Tensor to always false.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [constant_op.constant(False)] + old = collection_ref[:] + collection_ref[:] = [False] + yield + collection_ref[:] = old + + +class SummaryWriter(object): + """Encapsulates a summary writer.""" + + def __init__(self, resource): + self._resource = resource + + def __del__(self): + if context.in_eager_mode(): + resource_variable_ops.destroy_resource_op(self._resource) + + def set_as_default(self): + context.context().summary_writer_resource = self._resource + + @tf_contextlib.contextmanager + def as_default(self): + if self._resource is None: + yield + else: + old = context.context().summary_writer_resource + context.context().summary_writer_resource = self._resource + yield + # Flushes the summary writer in eager mode or in graph functions, but not + # in legacy graph mode (you're on your own there). + with ops.device("cpu:0"): + gen_summary_ops.flush_summary_writer(self._resource) + context.context().summary_writer_resource = old def create_summary_file_writer(logdir, @@ -70,22 +118,62 @@ def create_summary_file_writer(logdir, flush_secs=None, filename_suffix=None, name=None): - """Creates a summary file writer in the current context.""" - if max_queue is None: - max_queue = constant_op.constant(10) - if flush_secs is None: - flush_secs = constant_op.constant(120) - if filename_suffix is None: - filename_suffix = constant_op.constant("") - resource = gen_summary_ops.summary_writer(shared_name=name) - gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, - flush_secs, filename_suffix) - context.context().summary_writer_resource = resource + """Creates a summary file writer in the current context. + + Args: + logdir: a string, or None. If a string, creates a summary file writer + which writes to the directory named by the string. If None, returns + a mock object which acts like a summary writer but does nothing, + useful to use as a context manager. + max_queue: the largest number of summaries to keep in a queue; will + flush once the queue gets bigger than this. + flush_secs: the largest interval (in seconds) between flushes. + filename_suffix: optional suffix for the event file name. + name: name for the summary writer. + + Returns: + Either a summary writer or an empty object which can be used as a + summary writer. + """ + if logdir is None: + return SummaryWriter(None) + with ops.device("cpu:0"): + if max_queue is None: + max_queue = constant_op.constant(10) + if flush_secs is None: + flush_secs = constant_op.constant(120) + if filename_suffix is None: + filename_suffix = constant_op.constant("") + resource = gen_summary_ops.summary_writer(shared_name=name) + # TODO(apassos) ensure the initialization op runs when in graph mode; + # consider calling session.run here. + ops.add_to_collection( + _SUMMARY_WRITER_INIT_COLLECTION_NAME, + gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, + flush_secs, filename_suffix)) + return SummaryWriter(resource) def _nothing(): """Convenient else branch for when summaries do not record.""" - return False + return constant_op.constant(False) + + +def all_summary_ops(): + """Graph-mode only. Returns all summary ops.""" + if context.in_eager_mode(): + raise RuntimeError( + "tf.contrib.summary.all_summary_ops is only supported in graph mode.") + return ops.get_collection(_SUMMARY_COLLECTION_NAME) + + +def summary_writer_initializer_op(): + """Graph-mode only. Returns the list of ops to create all summary writers.""" + if context.in_eager_mode(): + raise RuntimeError( + "tf.contrib.summary.summary_writer_initializer_op is only " + "supported in graph mode.") + return ops.get_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME) def summary_writer_function(name, tensor, function, family=None): @@ -103,20 +191,27 @@ def summary_writer_function(name, tensor, function, family=None): def record(): with summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True + with ops.control_dependencies([function(tag, scope)]): + return constant_op.constant(True) - return control_flow_ops.cond( - should_record_summaries(), record, _nothing, name="") + if context.context().summary_writer_resource is None: + return control_flow_ops.no_op() + with ops.device("cpu:0"): + op = utils.smart_cond( + should_record_summaries(), record, _nothing, name="") + ops.add_to_collection(_SUMMARY_COLLECTION_NAME, op) + return op def generic(name, tensor, metadata, family=None): """Writes a tensor summary if possible.""" def function(tag, scope): - gen_summary_ops.write_summary(context.context().summary_writer_resource, - training_util.get_global_step(), tensor, - tag, metadata, name=scope) + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), array_ops.identity(tensor), + tag, metadata, name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -124,9 +219,11 @@ def scalar(name, tensor, family=None): """Writes a scalar summary if possible.""" def function(tag, scope): - gen_summary_ops.write_scalar_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_scalar_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -135,9 +232,11 @@ def histogram(name, tensor, family=None): """Writes a histogram summary if possible.""" def function(tag, scope): - gen_summary_ops.write_histogram_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_histogram_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -148,10 +247,12 @@ def image(name, tensor, bad_color=None, max_images=3, family=None): def function(tag, scope): if bad_color is None: bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - gen_summary_ops.write_image_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, bad_color_, max_images, - name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + bad_color_, + max_images, name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -160,11 +261,12 @@ def audio(name, tensor, sample_rate, max_outputs, family=None): """Writes an audio summary if possible.""" def function(tag, scope): - gen_summary_ops.write_audio_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_audio_summary( context.context().summary_writer_resource, training_util.get_global_step(), tag, - tensor, + array_ops.identity(tensor), sample_rate=sample_rate, max_outputs=max_outputs, name=scope) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 6958ee8dd83600d130293322c8680b3c0c0c02b2..de7ae6ec277a97235617882a7cc7e469eaebe26c 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,16 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import tempfile from tensorflow.contrib.summary import summary_ops -from tensorflow.core.util import event_pb2 +from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import errors from tensorflow.python.framework import test_util -from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -40,61 +38,53 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') def testShouldRecordSummary(self): - self.assertFalse(summary_ops.should_record_summaries().numpy()) - summary_ops.always_record_summaries() - self.assertTrue(summary_ops.should_record_summaries().numpy()) + self.assertFalse(summary_ops.should_record_summaries()) + with summary_ops.always_record_summaries(): + self.assertTrue(summary_ops.should_record_summaries()) def testSummaryOps(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') - summary_ops.always_record_summaries() - summary_ops.generic('tensor', 1, '') - summary_ops.scalar('scalar', 2.0) - summary_ops.histogram('histogram', [1.0]) - summary_ops.image('image', [[[[1.0]]]]) - summary_ops.audio('audio', [[1.0]], 1.0, 1) - # The working condition of the ops is tested in the C++ test so we just - # test here that we're calling them correctly. - self.assertTrue(gfile.Exists(logdir)) + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t0').as_default(), summary_ops.always_record_summaries(): + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + # The working condition of the ops is tested in the C++ test so we just + # test here that we're calling them correctly. + self.assertTrue(gfile.Exists(logdir)) def testDefunSummarys(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1') - summary_ops.always_record_summaries() - - @function.defun - def write(): - summary_ops.scalar('scalar', 2.0) + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t1').as_default(), summary_ops.always_record_summaries(): - write() + @function.defun + def write(): + summary_ops.scalar('scalar', 2.0) - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].simple_value, 2.0) + write() + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].simple_value, 2.0) def testSummaryName(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2') - summary_ops.always_record_summaries() - - summary_ops.scalar('scalar', 2.0) - - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].tag, 'scalar') + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + + summary_ops.scalar('scalar', 2.0) + + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scalar') if __name__ == '__main__': diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..37b546d3ab3220f934ea3bf7ef8f5fe6ab29f683 --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities to test summaries.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.core.util import event_pb2 +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import gfile + + +def events_from_file(logdir): + """Returns all events in the single eventfile in logdir.""" + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, "Found more than one file in logdir: %s" % files + records = list( + tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + result = [] + for r in records: + event = event_pb2.Event() + event.ParseFromString(r) + result.append(event) + return result diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index b6d57ef952777bc204f9534e60f2ce7de3687615..f80a34ece662d1e0b0ea1cb7616fa1b5b84731fa 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -235,9 +235,6 @@ class ProcessInputOp : public OpKernel { string serialized_proto; OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); input_spec_.ParseFromString(serialized_proto); - - data_set_ = std::unique_ptr( - new TensorDataSet(input_spec_, random_seed_)); } void Compute(OpKernelContext* context) override { @@ -249,8 +246,9 @@ class ProcessInputOp : public OpKernel { const Tensor& input_weights = context->input(7); const Tensor& leaf_ids_tensor = context->input(8); - data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values, sparse_input_shape); + std::unique_ptr data_set(new TensorDataSet(input_spec_, 0)); + data_set->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); FertileStatsResource* fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), @@ -264,7 +262,7 @@ class ProcessInputOp : public OpKernel { core::ScopedUnref unref_stats(fertile_stats_resource); core::ScopedUnref unref_tree(tree_resource); - const int32 num_data = data_set_->NumItems(); + const int32 num_data = data_set->NumItems(); auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; @@ -308,23 +306,23 @@ class ProcessInputOp : public OpKernel { // from a digits run on local desktop. Heuristics might be necessary // if it really matters that much. const int64 costPerUpdate = 1000; - auto update = [this, &target, &leaf_ids_tensor, &num_targets, + auto update = [this, &target, &leaf_ids_tensor, &num_targets, &data_set, fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - UpdateStats(fertile_stats_resource, data_set_, target, num_targets, + UpdateStats(fertile_stats_resource, data_set, target, num_targets, leaf_ids_tensor, &locks, &set_lock, static_cast(start), static_cast(end), &ready_to_split); }; auto update_collated = [this, &target, &num_targets, fertile_stats_resource, tree_resource, &leaf_examples, &set_lock, - &ready_to_split, + &ready_to_split, &data_set, num_leaves](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_leaves); - UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_, + UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set, target, num_targets, leaf_examples, &set_lock, static_cast(start), static_cast(end), &ready_to_split); @@ -350,7 +348,6 @@ class ProcessInputOp : public OpKernel { private: int32 random_seed_; tensorforest::TensorForestDataSpec input_spec_; - std::unique_ptr data_set_; TensorForestParams param_proto_; }; diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f0566322958a66eb3eb7cb9b47c7b94e70659e7d --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -0,0 +1,36 @@ +# Description: +# TensorBoard database code. + +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "schema", + srcs = ["schema.cc"], + hdrs = ["schema.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core/lib/db:sqlite", + ], +) + +tf_cc_test( + name = "schema_test", + srcs = ["schema_test.cc"], + deps = [ + ":schema", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/db:sqlite", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["*"]), + visibility = ["//tensorflow:__pkg__"], +) diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5a8e02a9bb2d2148a857b32d30641dfa8c9b89d --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -0,0 +1,412 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorboard/db/schema.h" + +namespace tensorflow { +namespace db { +namespace { + +class SqliteSchema { + public: + explicit SqliteSchema(Sqlite* db) : db_(db) {} + ~SqliteSchema() { db_ = nullptr; } + + /// \brief Creates Tensors table. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// tag_id: ID of associated Tag. + /// computed_time: Float UNIX timestamp with microsecond precision. + /// In the old summaries system that uses FileWriter, this is the + /// wall time around when tf.Session.run finished. In the new + /// summaries system, it is the wall time of when the tensor was + /// computed. On systems with monotonic clocks, it is calculated + /// by adding the monotonic run duration to Run.started_time. + /// This field is not indexed because, in practice, it should be + /// ordered the same or nearly the same as TensorIndex, so local + /// insertion sort might be more suitable. + /// step: User-supplied number, ordering this tensor in Tag. + /// If NULL then the Tag must have only one Tensor. + /// tensor: Can be an INTEGER (DT_INT64), FLOAT (DT_DOUBLE), or + /// BLOB. The structure of a BLOB is currently undefined, but in + /// essence it is a Snappy tf.TensorProto that spills over into + /// TensorChunks. + Status CreateTensorsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Tensors ( + rowid INTEGER PRIMARY KEY, + tag_id INTEGER NOT NULL, + computed_time REAL, + step INTEGER, + tensor BLOB + ) + )sql"); + } + + /// \brief Creates TensorChunks table. + /// + /// This table can be used to split up a tensor across many rows, + /// which has the advantage of not slowing down table scans on the + /// main table, allowing asynchronous fetching, minimizing copying, + /// and preventing large buffers from being allocated. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// tag_id: ID of associated Tag. + /// step: Same as corresponding Tensors.step. + /// sequence: 1-indexed sequence number for ordering chunks. Please + /// note that the 0th index is Tensors.tensor. + /// chunk: Bytes of next chunk in tensor. + Status CreateTensorChunksTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS TensorChunks ( + rowid INTEGER PRIMARY KEY, + tag_id INTEGER NOT NULL, + step INTEGER, + sequence INTEGER, + chunk BLOB + ) + )sql"); + } + + /// \brief Creates Tags table. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// tag_id: Permanent >0 unique ID. + /// run_id: Optional ID of associated Run. + /// tag_name: The tag field in summary.proto, unique across Run. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the wall time of when the row was inserted into the + /// DB. It may be used as a hint for an archival job. + /// metadata: Optional BLOB of SummaryMetadata proto. + /// display_name: Optional for GUI and defaults to tag_name. + /// summary_description: Optional markdown information. + Status CreateTagsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Tags ( + rowid INTEGER PRIMARY KEY, + run_id INTEGER, + tag_id INTEGER NOT NULL, + tag_name TEXT, + inserted_time DOUBLE, + metadata BLOB, + display_name TEXT, + description TEXT + ) + )sql"); + } + + /// \brief Creates Runs table. + /// + /// This table stores information about runs. Each row usually + /// represents a single attempt at training or testing a TensorFlow + /// model, with a given set of hyper-parameters, whose summaries are + /// written out to a single event logs directory with a monotonic step + /// counter. + /// + /// When a run is deleted from this table, TensorBoard should treat all + /// information associated with it as deleted, even if those rows in + /// different tables still exist. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// run_id: Permanent >0 unique ID. + /// experiment_id: Optional ID of associated Experiment. + /// run_name: User-supplied string, unique across Experiment. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the time the row was inserted into the database. It + /// does not change. + /// started_time: Float UNIX timestamp with µs precision. In the + /// old summaries system that uses FileWriter, this is + /// approximated as the first tf.Event.wall_time. In the new + /// summaries system, it is the wall time of when summary writing + /// started, from the perspective of whichever machine talks to + /// the database. This field will be mutated if the run is + /// restarted. + /// description: Optional markdown information. + /// graph: Snappy tf.GraphDef proto with node field cleared. That + /// field can be recreated using GraphNodes and NodeDefs. + Status CreateRunsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Runs ( + rowid INTEGER PRIMARY KEY, + experiment_id INTEGER, + run_id INTEGER NOT NULL, + run_name TEXT, + inserted_time REAL, + started_time REAL, + description TEXT, + graph BLOB + ) + )sql"); + } + + /// \brief Creates Experiments table. + /// + /// This table stores information about experiments, which are sets of + /// runs. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// user_id: Optional ID of associated User. + /// experiment_id: Permanent >0 unique ID. + /// experiment_name: User-supplied string, unique across User. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the time the row was inserted into the database. It + /// does not change. + /// started_time: Float UNIX timestamp with µs precision. This is + /// the MIN(experiment.started_time, run.started_time) of each + /// Run added to the database. + /// description: Optional markdown information. + Status CreateExperimentsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Experiments ( + rowid INTEGER PRIMARY KEY, + user_id INTEGER, + experiment_id INTEGER NOT NULL, + experiment_name TEXT, + inserted_time REAL, + started_time REAL, + description TEXT + ) + )sql"); + } + + /// \brief Creates Users table. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// user_id: Permanent >0 unique ID. + /// user_name: Unique user name. + /// email: Optional unique email address. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the time the row was inserted into the database. It + /// does not change. + Status CreateUsersTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Users ( + rowid INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + user_name TEXT, + email TEXT, + inserted_time REAL + ) + )sql"); + } + + /// \brief Creates NodeDefs table. + /// + /// This table stores NodeDef protos which define the GraphDef for a + /// Run. This functions like a hash table so rows can be shared by + /// multiple Runs in an Experiment. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// experiment_id: Optional int64 for grouping rows. + /// node_def_id: Permanent >0 unique ID. + /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed + /// node_def bytes, coerced to int64. + /// node_def: BLOB containing a Snappy tf.NodeDef proto. + Status CreateNodeDefsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS NodeDefs ( + rowid INTEGER PRIMARY KEY, + experiment_id INTEGER, + node_def_id INTEGER NOT NULL, + fingerprint INTEGER, + node_def TEXT + ) + )sql"); + } + + /// \brief Creates RunNodeDefs table. + /// + /// Table mapping Runs to NodeDefs. This is used to recreate the node + /// field of the GraphDef proto. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// run_id: Mandatory ID of associated Run. + /// node_def_id: Mandatory ID of associated NodeDef. + Status CreateRunNodeDefsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS RunNodeDefs ( + rowid INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + node_def_id INTEGER NOT NULL + ) + )sql"); + } + + /// \brief Uniquely indexes (tag_id, step) on Tensors table. + Status CreateTensorIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TensorIndex + ON Tensors (tag_id, step) + )sql"); + } + + /// \brief Uniquely indexes (tag_id, step, sequence) on TensorChunks table. + Status CreateTensorChunkIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TensorChunkIndex + ON TensorChunks (tag_id, step, sequence) + )sql"); + } + + /// \brief Uniquely indexes tag_id on Tags table. + Status CreateTagIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TagIdIndex + ON Tags (tag_id) + )sql"); + } + + /// \brief Uniquely indexes run_id on Runs table. + Status CreateRunIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS RunIdIndex + ON Runs (run_id) + )sql"); + } + + /// \brief Uniquely indexes experiment_id on Experiments table. + Status CreateExperimentIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS ExperimentIdIndex + ON Experiments (experiment_id) + )sql"); + } + + /// \brief Uniquely indexes user_id on Users table. + Status CreateUserIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS UserIdIndex + ON Users (user_id) + )sql"); + } + + /// \brief Uniquely indexes node_def_id on NodeDefs table. + Status CreateNodeDefIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex + ON NodeDefs (node_def_id) + )sql"); + } + + /// \brief Uniquely indexes (run_id, tag_name) on Tags table. + Status CreateTagNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TagNameIndex + ON Tags (run_id, tag_name) + WHERE tag_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes (experiment_id, run_name) on Runs table. + Status CreateRunNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS RunNameIndex + ON Runs (experiment_id, run_name) + WHERE run_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes (user_id, experiment_name) on Experiments table. + Status CreateExperimentNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS ExperimentNameIndex + ON Experiments (user_id, experiment_name) + WHERE experiment_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes user_name on Users table. + Status CreateUserNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS UserNameIndex + ON Users (user_name) + WHERE user_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes email on Users table. + Status CreateUserEmailIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS UserEmailIndex + ON Users (email) + WHERE email IS NOT NULL + )sql"); + } + + /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table. + Status CreateNodeDefFingerprintIndex() { + return Run(R"sql( + CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex + ON NodeDefs (experiment_id, fingerprint) + WHERE fingerprint IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table. + Status CreateRunNodeDefIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex + ON RunNodeDefs (run_id, node_def_id) + )sql"); + } + + Status Run(const char* sql) { + auto stmt = db_->Prepare(sql); + TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt->StepAndReset(), sql); + return Status::OK(); + } + + private: + Sqlite* db_; +}; + +} // namespace + +Status SetupTensorboardSqliteDb(Sqlite* db) { + SqliteSchema s(db); + TF_RETURN_IF_ERROR(s.CreateTensorsTable()); + TF_RETURN_IF_ERROR(s.CreateTensorChunksTable()); + TF_RETURN_IF_ERROR(s.CreateTagsTable()); + TF_RETURN_IF_ERROR(s.CreateRunsTable()); + TF_RETURN_IF_ERROR(s.CreateExperimentsTable()); + TF_RETURN_IF_ERROR(s.CreateUsersTable()); + TF_RETURN_IF_ERROR(s.CreateNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateTensorIndex()); + TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex()); + TF_RETURN_IF_ERROR(s.CreateTagIdIndex()); + TF_RETURN_IF_ERROR(s.CreateRunIdIndex()); + TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex()); + TF_RETURN_IF_ERROR(s.CreateUserIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex()); + TF_RETURN_IF_ERROR(s.CreateTagNameIndex()); + TF_RETURN_IF_ERROR(s.CreateRunNameIndex()); + TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex()); + TF_RETURN_IF_ERROR(s.CreateUserNameIndex()); + TF_RETURN_IF_ERROR(s.CreateUserEmailIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex()); + TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex()); + return Status::OK(); +} + +} // namespace db +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorboard/db/schema.h b/tensorflow/contrib/tensorboard/db/schema.h new file mode 100644 index 0000000000000000000000000000000000000000..d3a6922d94a50b3499263c7f58299bf75a4f60ac --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/schema.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ +#define TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/db/sqlite.h" + +namespace tensorflow { +namespace db { + +/// \brief Creates TensorBoard SQLite tables and indexes. +/// +/// If they are already created, this has no effect. If schema +/// migrations are necessary, they will be performed with logging. +Status SetupTensorboardSqliteDb(Sqlite* db); + +} // namespace db +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ diff --git a/tensorflow/contrib/tensorboard/db/schema_test.cc b/tensorflow/contrib/tensorboard/db/schema_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4302dda44764e9d9c2c9f8da7484675cb26d5a6 --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/schema_test.cc @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorboard/db/schema.h" + +#include + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace db { +namespace { + +TEST(SchemaTest, SmokeTestTensorboardSchema) { + std::unique_ptr db; + TF_ASSERT_OK(Sqlite::Open(":memory:", &db)); + TF_ASSERT_OK(SetupTensorboardSqliteDb(db.get())); +} + +} // namespace +} // namespace db +} // namespace tensorflow diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 015d0eba29f281d78ed6717271987cf3f2e121e9..222a77c4898bf705f98f98fba841bbfff5e852cc 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -25,6 +25,7 @@ py_test( srcs = ["predict_test.py"], data = ["data/period_trend.csv"], srcs_version = "PY2AND3", + tags = ["notsan"], # b/67513579 deps = [ ":predict", "//tensorflow/python:client_testlib", @@ -96,6 +97,7 @@ py_test( timeout = "long", # Moderate but for asan srcs = ["lstm_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [":lstm"], ) diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index da583a2ba0c063a55dc149a26b2c6c9d771e1a2a..7491b1b2d2718d3d368f46b641da1aa40d7ca5c9 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -105,6 +105,7 @@ py_test( tags = [ "no_pip_gpu", # b/63391119 "nomsan", # Takes too long to run. + "notsan", # b/67865658 ], deps = [ ":ar_model", @@ -371,6 +372,7 @@ py_test( "ar_model_test.py", ], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":ar_model", ":estimators", diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 3308f620d9624b38479ca9010ab969e75483b17e..3738dfa154d4f39b9562446972443ed88f3fbe8b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -20,8 +20,8 @@ from __future__ import print_function from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys -from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib +from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import state_management from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model from tensorflow.contrib.timeseries.python.timeseries.state_space_models import structural_ensemble @@ -59,9 +59,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator): if optimizer is None: optimizer = train.AdamOptimizer(0.02) self._model = model - model_fn = ts_head_lib.time_series_regression_head( + ts_regression_head = ts_head_lib.time_series_regression_head( model, state_manager, optimizer, - input_statistics_generator=input_statistics_generator).create_estimator_spec + input_statistics_generator=input_statistics_generator) + model_fn = ts_regression_head.create_estimator_spec super(TimeSeriesRegressor, self).__init__( model_fn=model_fn, model_dir=model_dir, diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index a8e22566cda086d414231fe81632252ab530ca50..5896fc2a206bc747688b5b012e0f87465592dd8a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -1,3 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Timeseries head.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -22,31 +37,34 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest -def time_series_regression_head( - model, state_manager, optimizer, input_statistics_generator=None): +def time_series_regression_head(model, + state_manager, + optimizer, + input_statistics_generator=None): """Creates a `_Head` for time series regression. Args: - weight_column: A string or a `_NumericColumn` created by - `tf.feature_column.numeric_column` defining feature column representing - weights. It is used to down weight or boost examples during training. It - will be multiplied by the loss of the example. - label_dimension: Number of regression labels per example. This is the size - of the last dimension of the labels `Tensor` (typically, this has shape - `[batch_size, label_dimension]`). + model: A model for time series regression. + state_manager: A state manager. + optimizer: An optimizer. + input_statistics_generator: A input statistics generator. Returns: An instance of `_Head` for time series regression. """ - return _TimeSeriesRegressionHead( - model, state_manager, optimizer, input_statistics_generator) + return _TimeSeriesRegressionHead(model, state_manager, optimizer, + input_statistics_generator) class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access """See `time_series_regression_head`.""" - def __init__(self, model, state_manager, optimizer, - input_statistics_generator=None, name=None): + def __init__(self, + model, + state_manager, + optimizer, + input_statistics_generator=None, + name=None): self.model = model self.state_manager = state_manager self.optimizer = optimizer @@ -56,31 +74,33 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _train_ops(self, features): """Add training ops to the graph.""" with variable_scope.variable_scope("model"): - model_outputs = self.state_manager.define_loss(self.model, features, - estimator_lib.ModeKeys.TRAIN) + model_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.TRAIN) + train_op = optimizers.optimize_loss( - model_outputs.loss, - global_step=variables.get_global_step(), - optimizer=self.optimizer, - # Learning rate is set in the Optimizer object - learning_rate=None) + model_outputs.loss, + global_step=variables.get_global_step(), + optimizer=self.optimizer, + # Learning rate is set in the Optimizer object + learning_rate=None) return estimator_lib.EstimatorSpec( - loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.TRAIN, - train_op=train_op) + loss=model_outputs.loss, + mode=estimator_lib.ModeKeys.TRAIN, + train_op=train_op) - # TODO: suffix summary and metrics keys by `"/" + name` + # TODO(terrytangyuan): suffix summary and metrics keys by `"/" + name` @property def name(self): return self._name - # TOOD: unused for now. Need to decouple `state_manager.define_loss` - # to satisfy the extendable return signature of `_Head.create_loss`. + # TODO(terrytangyuan): unused for now. Need to decouple + # `state_manager.define_loss` to satisfy the extendable return signature of + # `_Head.create_loss`. def create_loss(self, features, mode, logits, labels): """See `_Head`.""" return None - # TODO: check label dimension + # TODO(terrytangyuan): check label dimension @property def logits_dimension(self): return None @@ -88,58 +108,59 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _evaluate_ops(self, features): """Add ops for evaluation (aka filtering) to the graph.""" with variable_scope.variable_scope("model"): - model_outputs = self.state_manager.define_loss(self.model, features, - estimator_lib.ModeKeys.EVAL) + model_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.EVAL) metrics = {} # Just output in-sample predictions for the last chunk seen for prediction_key, prediction_value in model_outputs.predictions.items(): metrics[prediction_key] = _identity_metric_single(prediction_key, prediction_value) metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single( - feature_keys.FilteringResults.TIMES, model_outputs.prediction_times) + feature_keys.FilteringResults.TIMES, model_outputs.prediction_times) metrics[feature_keys.FilteringResults.STATE_TUPLE] = ( - _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE, - model_outputs.end_state)) + _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE, + model_outputs.end_state)) return estimator_lib.EstimatorSpec( - loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.EVAL, - eval_metric_ops=metrics, - predictions={}) + loss=model_outputs.loss, + mode=estimator_lib.ModeKeys.EVAL, + eval_metric_ops=metrics, + predictions={}) def _predict_ops(self, features): """Add ops for prediction to the graph.""" with variable_scope.variable_scope("model"): prediction = self.model.predict(features=features) prediction[feature_keys.PredictionResults.TIMES] = features[ - feature_keys.PredictionFeatures.TIMES] + feature_keys.PredictionFeatures.TIMES] return estimator_lib.EstimatorSpec( - predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT) + predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT) def _serving_ops(self, features): """Add ops for serving to the graph.""" with variable_scope.variable_scope("model"): prediction_outputs = self.model.predict(features=features) with variable_scope.variable_scope("model", reuse=True): - filtering_outputs = self.state_manager.define_loss(self.model, features, - estimator_lib.ModeKeys.EVAL) + filtering_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.EVAL) + return estimator_lib.EstimatorSpec( - mode=estimator_lib.ModeKeys.PREDICT, - export_outputs={ - feature_keys.SavedModelLabels.PREDICT: - export_lib.PredictOutput(prediction_outputs), - feature_keys.SavedModelLabels.FILTER: - export_lib.PredictOutput( - state_to_dictionary(filtering_outputs.end_state)) - }, - # Likely unused, but it is necessary to return `predictions` to satisfy - # the Estimator's error checking. - predictions={}) + mode=estimator_lib.ModeKeys.PREDICT, + export_outputs={ + feature_keys.SavedModelLabels.PREDICT: + export_lib.PredictOutput(prediction_outputs), + feature_keys.SavedModelLabels.FILTER: + export_lib.PredictOutput( + state_to_dictionary(filtering_outputs.end_state)) + }, + # Likely unused, but it is necessary to return `predictions` to satisfy + # the Estimator's error checking. + predictions={}) def _convert_feature_to_tensor(self, name, value): """Casts features to the correct dtype based on their name.""" if name in [ - feature_keys.TrainEvalFeatures.TIMES, - feature_keys.PredictionFeatures.TIMES + feature_keys.TrainEvalFeatures.TIMES, + feature_keys.PredictionFeatures.TIMES ]: return math_ops.cast(value, dtypes.int64) if name == feature_keys.TrainEvalFeatures.VALUES: @@ -164,39 +185,45 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc del features[key] numbered_state.sort(key=lambda number, *_: number) features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as( - structure=self.model.get_start_state(), - flat_sequence=[tensor for _, _, tensor in numbered_state]) + structure=self.model.get_start_state(), + flat_sequence=[tensor for _, _, tensor in numbered_state]) return features, True def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope("head"): if labels: - raise ValueError("The model received a `labels` dictionary, which is not" - " supported. Pass '{}' and '{}' as features.".format( - feature_keys.TrainEvalFeatures.TIMES, - feature_keys.TrainEvalFeatures.VALUES)) + raise ValueError( + "The model received a `labels` dictionary, which is " + "not supported. Pass '{}' and '{}' as " + "features.".format(feature_keys.TrainEvalFeatures.TIMES, + feature_keys.TrainEvalFeatures.VALUES)) del labels - features = {name: self._convert_feature_to_tensor(name=name, value=value) - for name, value in features.items()} + features = { + name: self._convert_feature_to_tensor(name=name, value=value) + for name, value in features.items() + } if self.input_statistics_generator is not None: input_statistics = self.input_statistics_generator.initialize_graph( - features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN)) + features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN)) else: input_statistics = None self.model.initialize_graph(input_statistics=input_statistics) - # _gather_state requires the model to have its graph initialized (so it has - # access to the structure of the model's state) + + # _gather_state requires the model to have its graph initialized (so it + # has access to the structure of the model's state) features, passed_flat_state = self._gather_state(features) - if (mode == estimator_lib.ModeKeys.TRAIN - or mode == estimator_lib.ModeKeys.EVAL): + if (mode == estimator_lib.ModeKeys.TRAIN or + mode == estimator_lib.ModeKeys.EVAL): _check_train_eval_features(features, self.model) elif mode == estimator_lib.ModeKeys.PREDICT: _check_predict_features(features) else: raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) + self.state_manager.initialize_graph( - model=self.model, input_statistics=input_statistics) + model=self.model, input_statistics=input_statistics) + if mode == estimator_lib.ModeKeys.TRAIN: return self._train_ops(features) elif mode == estimator_lib.ModeKeys.EVAL: @@ -210,8 +237,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc return self._serving_ops(features) -def _check_feature_shapes_compatible_with( - features, compatible_with_name, compatible_with_value, ignore=None): +def _check_feature_shapes_compatible_with(features, + compatible_with_name, + compatible_with_value, + ignore=None): """Checks all features are compatible with the given time-like feature.""" if ignore is None: ignore = set() @@ -223,77 +252,77 @@ def _check_feature_shapes_compatible_with( continue if feature_shape.ndims < 2: raise ValueError( - ("Features must have shape (batch dimension, window size, ...) " - "(got rank {} for feature '{}')").format( - feature_shape.ndims, name)) + ("Features must have shape (batch dimension, window size, ...) " + "(got rank {} for feature '{}')").format(feature_shape.ndims, name)) if not feature_shape[:2].is_compatible_with( - compatible_with_value.get_shape()): + compatible_with_value.get_shape()): raise ValueError( - ("Features must have shape (batch dimension, window size, ...) " - "where batch dimension and window size match the " - "'{times_feature}' feature (got shape {feature_shape} for " - "feature '{feature_name}' but shape {times_shape} for feature " - "'{times_feature}')").format( - times_feature=compatible_with_name, - feature_shape=feature_shape, - feature_name=name, - times_shape=compatible_with_value.get_shape())) + ("Features must have shape (batch dimension, window size, ...) " + "where batch dimension and window size match the " + "'{times_feature}' feature (got shape {feature_shape} for " + "feature '{feature_name}' but shape {times_shape} for feature " + "'{times_feature}')").format( + times_feature=compatible_with_name, + feature_shape=feature_shape, + feature_name=name, + times_shape=compatible_with_value.get_shape())) def _check_predict_features(features): """Raises errors if features are not suitable for prediction.""" if feature_keys.PredictionFeatures.TIMES not in features: raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.TIMES)) + feature_keys.PredictionFeatures.TIMES)) if feature_keys.PredictionFeatures.STATE_TUPLE not in features: raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.STATE_TUPLE)) + feature_keys.PredictionFeatures.STATE_TUPLE)) times_feature = features[feature_keys.PredictionFeatures.TIMES] if not times_feature.get_shape().is_compatible_with([None, None]): raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, - times_feature.get_shape())) + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.PredictionFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes - ])) + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes + ])) def _check_train_eval_features(features, model): """Raise errors if features are not suitable for training/evaluation.""" if feature_keys.TrainEvalFeatures.TIMES not in features: raise ValueError("Expected a '{}' feature for training/evaluation.".format( - feature_keys.TrainEvalFeatures.TIMES)) + feature_keys.TrainEvalFeatures.TIMES)) if feature_keys.TrainEvalFeatures.VALUES not in features: raise ValueError("Expected a '{}' feature for training/evaluation.".format( - feature_keys.TrainEvalFeatures.VALUES)) + feature_keys.TrainEvalFeatures.VALUES)) times_feature = features[feature_keys.TrainEvalFeatures.TIMES] if not times_feature.get_shape().is_compatible_with([None, None]): raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES, - times_feature.get_shape())) + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES, + times_feature.get_shape())) values_feature = features[feature_keys.TrainEvalFeatures.VALUES] if not values_feature.get_shape().is_compatible_with( - [None, None, model.num_features]): + [None, None, model.num_features]): raise ValueError( - ("Expected shape (batch dimension, window size, {num_features}) " - "for feature '{feature_name}', since the model was configured " - "with num_features={num_features} (got shape {got_shape})").format( - num_features=model.num_features, - feature_name=feature_keys.TrainEvalFeatures.VALUES, - got_shape=times_feature.get_shape())) + ("Expected shape (batch dimension, window size, {num_features}) " + "for feature '{feature_name}', since the model was configured " + "with num_features={num_features} (got shape {got_shape})").format( + num_features=model.num_features, + feature_name=feature_keys.TrainEvalFeatures.VALUES, + got_shape=times_feature.get_shape())) _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.TrainEvalFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.State.STATE_TUPLE # Model-dependent shapes - ])) + features=features, + compatible_with_name=feature_keys.TrainEvalFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + feature_keys.State.STATE_TUPLE # Model-dependent shapes + ])) + def _identity_metric_single(name, input_tensor): """A metric which takes on its last updated value. @@ -311,12 +340,12 @@ def _identity_metric_single(name, input_tensor): A tuple of (value, update_op). """ metric_variable = variable_scope.variable( - name="{}_identity_metric".format(name), - initial_value=array_ops.zeros([], dtype=input_tensor.dtype), - collections=[ops.GraphKeys.LOCAL_VARIABLES], - validate_shape=False) - update_op = state_ops.assign(metric_variable, input_tensor, - validate_shape=False) + name="{}_identity_metric".format(name), + initial_value=array_ops.zeros([], dtype=input_tensor.dtype), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + validate_shape=False) + update_op = state_ops.assign( + metric_variable, input_tensor, validate_shape=False) # This shape will be correct once the first update runs (but may be # incomplete, so is not helpful for initializing the variable). metric_variable.set_shape(input_tensor.get_shape()) @@ -329,13 +358,13 @@ def _identity_metric_nested(name, input_tensors): value_tensors = [] for tensor_number, tensor in enumerate(nest.flatten(input_tensors)): value_tensor, update_op = _identity_metric_single( - name="{}_{}".format(name, tensor_number), - input_tensor=tensor) + name="{}_{}".format(name, tensor_number), input_tensor=tensor) update_ops.append(update_op) value_tensors.append(value_tensor) return (nest.pack_sequence_as(input_tensors, value_tensors), control_flow_ops.group(*update_ops)) + def state_to_dictionary(state_tuple): """Flatten model state into a dictionary with string keys.""" flattened = {} @@ -344,4 +373,3 @@ def state_to_dictionary(state_tuple): state_number) flattened[prefixed_state_name] = state_value return flattened - diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 7ebcebfe1b156a6c0cc86fa1ded55e4a645d291f..3415061cfd87358cccaf36dcb301fb36986bbde6 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.timeseries.python.timeseries import feature_keys -from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib +from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import state_management from tensorflow.python.estimator import estimator_lib diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index e753fe7a5140028f238c2ff3754b1d7335ae8eb2..c89596734c738467c58e845328e396c3f2eb999a 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -30,11 +30,24 @@ cc_library( ], ) +py_library( + name = "tpu_test_util", + srcs = [ + "python/tpu/test_util.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":tpu_lib", + ":tpu_py", + ], +) + py_library( name = "tpu_estimator", srcs = [ "python/tpu/tpu_config.py", "python/tpu/tpu_estimator.py", + "python/tpu/util.py", ], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index a40e2a7898a304c21a60929b30719f3132aec0f0..b40dac471708793d5a033279e2d2f4b4a0dac480 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -22,6 +22,11 @@ namespace tensorflow { using shape_inference::InferenceContext; using shape_inference::ShapeHandle; +REGISTER_OP("TPUReplicateMetadata") + .Attr("num_replicas: int >= 0") + .Attr("global_tpu_id: list(int) = []") + .SetShapeFn(shape_inference::UnknownShape); + REGISTER_OP("TPUReplicatedInput") .Input("inputs: N * T") .Output("output: T") diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 8d3344fac36be24a692f141eee140312d988a932..33e47f674d798f622fb08121dabb67d7f45af15b 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -21,9 +21,11 @@ from __future__ import print_function import platform +from tensorflow.python.framework import ops if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tpu.ops import gen_tpu_ops from tensorflow.contrib.tpu.ops.gen_tpu_ops import * from tensorflow.contrib.util import loader @@ -32,6 +34,12 @@ if platform.system() != "Windows": _tpu_ops = loader.load_op_library( resource_loader.get_path_to_datafile("_tpu_ops.so")) + + @ops.RegisterGradient("CrossReplicaSum") + def _cross_replica_sum_grad(op, grad): + del op # Unused + # The gradient of a cross replica sum is also a cross-replica sum. + return gen_tpu_ops.cross_replica_sum(grad) else: # We have already built the appropriate libraries into the binary via CMake # if we have built contrib, so we don't need this diff --git a/tensorflow/contrib/tpu/python/tpu/test_util.py b/tensorflow/contrib/tpu/python/tpu/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f30c27f1298e2389fe0daefdd4eece5a03a6976c --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/test_util.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""Utilities to ease testing on TPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import tpu + +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import variables + + +def has_tpu(): + """Check if a TPU device is available. + + Device enumeration via `device_lib` currently fails for TPU systems. + (http://b/68333779). To work around this, we determine the existence of a + TPU by a successful call to `initialize_system`. + + Returns: + boolean, True if a TPU device is available, otherwise False. + """ + def _check(): + with session.Session() as sess: + sess.run(tpu.initialize_system()) + sess.run(tpu.shutdown_system()) + + try: + _check() + return True + except errors.OpError as _: + return False + + +def _available_devices(): + devices = ["cpu"] + if not test_util.gpu_device_name(): + devices.append("gpu") + + if has_tpu(): + devices.append("tpu") + + return tuple(devices) + + +class TPUTestCase(test_util.TensorFlowTestCase): + """Adds helpers for testing on TPU devices to `TensorFlowTestCase`. + + Example usage: + + ``` + def model_fn(features): + return tf.reduce_sum(features * 2) + + class ModelTests(test_util.TPUTestCase): + def test_sum(self): + v = np.random.randn(10, 10).astype("float32") + self.assert_device_output(model_fn, [v], (v*2).sum(), + devices=("cpu", "tpu")) + ``` + """ + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(TPUTestCase, self).__init__(methodName) + self._available_devices = _available_devices() + + def run_on_device(self, model_fn, model_inputs, device): + """Runs `model_fn` on the given device. + + Raises an exception if no such device is available. `model_fn` should + return one or more tensors as a list or tuple. + + Args: + model_fn: Function returning one or more tensors. + model_inputs: An iterable of Numpy arrays or scalars. + These will be passed as arguments to `model_fn`. + device: Device to run on. One of ("tpu", "gpu", "cpu"). + + Returns: + Output from the model function. + """ + def _make_placeholders(): + return dict( + [(gen_array_ops.placeholder_with_default(v, v.shape), v) + for v in model_inputs]) + + if device == "tpu": + with self.test_session(graph=ops.Graph()) as sess: + placeholders = _make_placeholders() + tpu_computation = tpu.rewrite(model_fn, placeholders.keys()) + sess.run(tpu.initialize_system()) + sess.run(variables.global_variables_initializer()) + result = sess.run(tpu_computation, placeholders) + sess.run(tpu.shutdown_system()) + # TODO(b/36891278): supports non-flat returns lists in tpu.rewrite(). + if len(result) == 1: + return result[0] + return result + elif device == "gpu": + with self.test_session(graph=ops.Graph(), use_gpu=True) as sess: + placeholders = _make_placeholders() + sess.run(variables.global_variables_initializer()) + return sess.run(model_fn(placeholders.keys()), placeholders) + elif device == "cpu": + # TODO(power) -- will this interact poorly with cached GPU sessions? + with self.test_session(graph=ops.Graph(), use_gpu=False) as sess: + placeholders = _make_placeholders() + sess.run(variables.global_variables_initializer()) + return sess.run(model_fn(placeholders.keys()), placeholders) + + def _compare_values(self, actual_outputs, expected_outputs): + if isinstance(expected_outputs, (list, tuple)): + for a, b in zip(actual_outputs, expected_outputs): + self.assertAllCloseAccordingToType(a, b) + else: + self.assertAllCloseAccordingToType(actual_outputs, expected_outputs) + + def assert_device_output(self, model_fn, model_inputs, expected_outputs, + devices=("cpu", "gpu", "tpu")): + """Run `model_fn` on the given devices. + + Results are compared via `assertAllCloseAccordingToType`. + + Args: + model_fn: Function returning one or more tensors + model_inputs: Numpy arrays or scalars passed as arguments to model_fn + expected_outputs: Numpy arrays or scalars to compare against. + devices: Set of devices to run on. If a device is not available, tests + will be skipped for that device. + """ + devices = set(devices).intersection(self._available_devices) + + for device in devices: + device_out = self.run_on_device(model_fn, model_inputs, device=device) + self._compare_values(device_out, expected_outputs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index f6800e3e246dc5f6242a7bf127f6397fedf92b9f..338a4304f3272f3486c88e6e2aeb90fec15e4f58 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -105,9 +105,8 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): """A ControlFlowContext for nodes inside a TPU computation. The primary role of TPUReplicateContext is to mark operators inside a - tpu.replicate() computation with attributes: - * _tpu_replicate=XYZ, where XYZ is a unique name, and - * _tpu_num_replicas=k, where k is the number of replicas. + tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ + is a unique name. We use a ControlFlowContext to perform the annotation since it integrates with Tensorflow constructs like ResourceVariables. For example, @@ -116,11 +115,9 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): to build the variable's definition outside the replicated computation. """ - def __init__(self, name, num_replicas, global_tpu_id=None): + def __init__(self, name): control_flow_ops.ControlFlowContext.__init__(self) self._name = name - self._num_replicas = num_replicas - self._global_tpu_id = [] if global_tpu_id is None else global_tpu_id def AddOp(self, op): self._AddOpInternal(op) @@ -135,8 +132,6 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): if "_tpu_replicate" in op.node_def.attr: raise ValueError("TPU computations cannot be nested") op.node_def.attr["_tpu_replicate"].s = self._name - op.node_def.attr["_tpu_num_replicas"].i = self._num_replicas - op.node_def.attr["_tpu_global_id"].list.i.extend(self._global_tpu_id) op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) @@ -151,6 +146,14 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): if self._outer_context: self._outer_context.AddInnerOp(op) + @property + def grad_state(self): + # Define the gradient loop state associated with the TPUReplicateContext to + # be None as the TPUReplicateContext does not get nested nor does the + # grad_state outside the TPUReplicateContext affect the graph inside so the + # grad_state should be as if this is the top-level gradient state. + return None + def replicate(computation, inputs=None, @@ -243,14 +246,15 @@ def replicate(computation, computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - context = TPUReplicateContext( - name=graph.unique_name("cluster"), - num_replicas=num_replicas, - global_tpu_id=global_tpu_id) + context = TPUReplicateContext(name=graph.unique_name("cluster")) try: context.Enter() - with tpu_function.tpu_shard_context(num_replicas): + metadata = tpu_ops.tpu_replicate_metadata( + num_replicas=num_replicas, global_tpu_id=global_tpu_id) + + with tpu_function.tpu_shard_context( + num_replicas), ops.control_dependencies([metadata]): # The EncapsulateTPUComputations rewrite needs to identify the # replicated arguments inside each computation. Adds identity operators diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 02135bfe40e72860d474f441b6cc57430d4e0fca..3965c087a18dc18298703fad9b1dda9c85c56271 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -21,12 +21,16 @@ from __future__ import print_function import collections +from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.python.estimator import run_config as run_config_lib class TPUConfig( collections.namedtuple('TPUConfig', [ - 'iterations_per_loop', 'num_shards', 'per_host_input_for_training' + 'iterations_per_loop', + 'num_shards', + 'per_host_input_for_training', + 'tpu_job_name', ])): """TPU related configuration required by `TPUEstimator`. @@ -36,31 +40,62 @@ class TPUConfig( global step is increased `iterations_per_loop` times in one `Session.run`. It is recommended to be set as number of global steps for next checkpoint. num_shards: The number of TPU shards in the system. - per_host_input_for_training: If `True`, `input_fn` is invoked per host - rather than per shard. Note: This behavior is going to be default as - `True` soon, so this flag will be removed after that. Also note that this - only works for single-host TPU training now. + per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host + rather than Per-Core. With Per-Host input pipeline deployment, `input_fn` + is invoked once on each host. To be precise, with a global batch size + `train_batch_size` in `TPUEstimator` constructor, the batch size for each + shard is `train_batch_size` // #hosts. With Per-Core input pipeline + deployment, the shard batch size is `train_batch_size` // #cores. Note + that this only works for single-host TPU training now (tracked in + b/67051042). For multi-host, please use Per-Core, i.e., `False` for + `per_host_input_for_training`. + tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred + within TPUEstimator, however when using ClusterSpec propagation in more + esoteric cluster configurations, you may need to specify the job name as a + string. """ def __new__(cls, iterations_per_loop=2, num_shards=2, - per_host_input_for_training=False): + per_host_input_for_training=True, + tpu_job_name=None): + + # Check iterations_per_loop. + util_lib.check_positive_integer(iterations_per_loop, + 'TPUConfig iterations_per_loop') + + # Check num_shards. + util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') return super(TPUConfig, cls).__new__( cls, iterations_per_loop=iterations_per_loop, num_shards=num_shards, - per_host_input_for_training=per_host_input_for_training) + per_host_input_for_training=per_host_input_for_training, + tpu_job_name=tpu_job_name) class RunConfig(run_config_lib.RunConfig): """RunConfig with TPU support.""" - def __init__(self, tpu_config=None, evaluation_master='', master='', + def __init__(self, tpu_config=None, evaluation_master=None, master='', **kwargs): + """Constructs a RunConfig. + + Args: + tpu_config: the TPUConfig that specifies TPU-specific configuration. + evaluation_master: a string. The address of the master to use for eval. + Defaults to master if not set. + master: a string. The address of the master to use for training. + tf_random_seed: an int. Sets the TensorFlow random seed. Defaults to None, + which initializes it randomly based on the environment. + """ super(RunConfig, self).__init__(**kwargs) self._tpu_config = tpu_config or TPUConfig() - self._evaluation_master = evaluation_master + if evaluation_master is None: + self._evaluation_master = master + else: + self._evaluation_master = evaluation_master self._master = master @property diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index b5001d596b8cb34b0cfd32df0864e466ab7d86b6..5a3b8314291951b5dfce091dccb0dc9e5f7af3b5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +from contextlib import contextmanager import copy import threading import six @@ -31,12 +32,14 @@ from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -56,12 +59,15 @@ from tensorflow.python.training import training_util _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. -_DEFAULT_NAME_SCOPE = 'tpu_estimator' +_TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] +# TODO(b/65703635): Flip the value and remove all dead code. +_WRAP_INPUT_FN_INTO_WHILE_LOOP = False + def _create_global_step(graph): graph = graph or ops.get_default_graph() @@ -80,17 +86,25 @@ def _create_global_step(graph): ops.GraphKeys.GLOBAL_STEP]) -def _create_iterations_per_loop(): - with variable_scope.variable_scope(_DEFAULT_NAME_SCOPE, - reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[], - use_resource=True) +def _create_or_get_iterations_per_loop(): + graph = ops.get_default_graph() + iter_vars = graph.get_collection(_TPU_ESTIMATOR) + if len(iter_vars) == 1: + return iter_vars[0] + elif len(iter_vars) > 1: + raise RuntimeError('Multiple iterations_per_loop_var in collection.') + + with ops.colocate_with(training_util.get_global_step()): + with variable_scope.variable_scope(_TPU_ESTIMATOR, + reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + _ITERATIONS_PER_LOOP_VAR, + initializer=init_ops.zeros_initializer(), + shape=[], + dtype=dtypes.int32, + trainable=False, + collections=[_TPU_ESTIMATOR], + use_resource=True) def _sync_variables_ops(): @@ -121,26 +135,214 @@ def _increase_eval_step_op(iterations_per_loop): use_locking=True) -def _tpu_job(run_config, mode): - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL - else run_config.master) - return None if master in ['', 'local'] else 'tpu_worker' +_DEFAULT_JOB_NAME = 'tpu_worker' +_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' +_LOCAL_MASTERS = ('', 'local') + + +class _TPUContext(object): + """A context holds immutable states of TPU computation. + + This immutable object holds TPUEstimator config, train/eval batch size, and + `TPUEstimator.use_tpu`, which is expected to be passed around. It also + provides utility functions, basded on the current state, to determine other + information commonly required by TPU computation, such as TPU device names, + TPU hosts, shard batch size, etc. + + N.B. As `mode` is not immutable state in Estimator, but essential to + distinguish between TPU training and evaluation, a common usage for + _TPUContext with `mode` is as follows: + ``` + with _ctx.with_mode(mode) as ctx: + if ctx.is_running_on_cpu(): + ... + ``` + """ + + def __init__(self, config, train_batch_size, eval_batch_size, use_tpu): + self._config = config + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._use_tpu = use_tpu + self._num_shards_or_none = self._config.tpu_config.num_shards + self._mode = None + + def _assert_mode(self): + if self._mode is None: + raise RuntimeError( + '`mode` needs to be set via contextmanager `with_mode`.') + return self._mode + + @property + def num_of_cores_per_host(self): + num_cores = self.num_cores + return min(num_cores, 8) + + @contextmanager + def with_mode(self, mode): + new_ctx = copy.copy(self) # Shallow copy is enough. + new_ctx._mode = mode # pylint: disable=protected-access + yield new_ctx + + @property + def mode(self): + return self._assert_mode() + + @property + def num_cores(self): + # TODO(xiejw): Adds lazy num_shards initialization. + return self._num_shards_or_none + + @property + def num_hosts(self): + return self.num_cores // self.num_of_cores_per_host + + @property + def config(self): + return self._config + + def is_input_sharded_per_core(self): + """Return true if input_fn is invoked per-core (other than per-host).""" + self._assert_mode() + return (self._mode == model_fn_lib.ModeKeys.TRAIN and + not self._config.tpu_config.per_host_input_for_training) + + def is_running_on_cpu(self): + """Determines whether the input_fn and model_fn should be invoked on CPU.""" + mode = self._assert_mode() + return ((not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or + (mode == model_fn_lib.ModeKeys.EVAL and + self._eval_batch_size is None)) + + @property + def batch_size_for_input_fn(self): + """Returns the shard batch size for `input_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + global_batch_size = (self._train_batch_size if + mode == model_fn_lib.ModeKeys.TRAIN + else self._eval_batch_size) + # On TPU + return (global_batch_size // self.num_cores + if self.is_input_sharded_per_core() else global_batch_size) + + @property + def batch_size_for_model_fn(self): + """Returns the shard batch size for `model_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + # On TPU. always sharded per core. + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size // self.num_cores + else: + return self._eval_batch_size // self.num_cores + + @property + def master_job(self): + """Returns the job name to use to place TPU computations on. + + Returns: + A string containing the job name, or None if no job should be specified. + + Raises: + ValueError: If the user needs to specify a tpu_job_name, because we are + unable to infer the job name automatically, or if the user-specified job + names are inappropriate. + """ + run_config = self._config + # If the user specifies the tpu_job_name, use that. + if run_config.tpu_config.tpu_job_name: + return run_config.tpu_config.tpu_job_name + + # The tpu job is determined by the run_config. Right now, this method is + # required as tpu_config is not part of the RunConfig. + mode = self._assert_mode() + master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL + else run_config.master) + if master in _LOCAL_MASTERS: + return None + + if (not run_config.session_config or + not run_config.session_config.cluster_def.job): + return _DEFAULT_JOB_NAME + cluster_def = run_config.session_config.cluster_def + job_names = set([job.name for job in cluster_def.job]) + if _DEFAULT_JOB_NAME in job_names: + # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. + raise ValueError('Currently, tpu_worker is not an allowed job name.') + if len(job_names) == 1: + return cluster_def.job[0].name + if len(job_names) == 2: + if _DEFAULT_COORDINATOR_JOB_NAME in job_names: + job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) + return job_names.pop() + # TODO(b/67716447): Include more sophisticated heuristics. + raise ValueError( + 'Could not infer TPU job name. Please specify a tpu_job_name as part ' + 'of your TPUConfig.') + + @property + def tpu_host_placement_function(self): + """Returns the TPU host place function.""" + master = self.master_job + def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + assert _sentinal is None + if core_id is not None and host_id is not None: + raise RuntimeError( + 'core_id and host_id can have only one non-None value.') + + if master is None: + return '/replica:0/task:0/device:CPU:0' + else: + # This assumes that if using more than 8 shards, + # the job configuration varies 'task'. + if core_id is not None: + host_id = core_id / 8 + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) + return _placement_function + + @property + def tpu_device_placement_function(self): + master = self.master_job + job_device = '' if master is None else ('/job:%s' % master) + def _placement_function(i): + return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8) + return _placement_function + @property + def tpu_ordinal_function(self): + """Returns the TPU ordinal fn.""" + def _tpu_ordinal_function(index): + """Return the TPU ordinal associated with a shard. -def _is_running_on_cpu(use_tpu, mode, eval_batch_size): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - return ((not use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or - (mode == model_fn_lib.ModeKeys.EVAL and eval_batch_size is None)) + Required because the enqueue ops are placed on CPU. + Args: + index: the shard index -def _per_shard_batch_size(global_batch_size, run_config, use_tpu): - """Returns the batch size for each shard.""" - if use_tpu: - return global_batch_size // run_config.tpu_config.num_shards - else: - return global_batch_size + Returns: + The ordinal of the TPU device the shard's infeed should be placed on. + """ + return index % 8 + return _tpu_ordinal_function class _SIGNAL(object): @@ -268,17 +470,30 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController): def _input_thread_fn_for_loading(self, session, enqueue_ops): count = 0 - while True: - signal = self._signal_queue.get() - if signal == _SIGNAL.STOP: - logging.info('Stop Infeed input thread.') - return - - iterations = signal - for i in range(iterations): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(enqueue_ops) - count += 1 + try: + while True: + signal = self._signal_queue.get() + if signal == _SIGNAL.STOP: + logging.info('Stop Infeed input thread.') + return + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + # Enqueue batches for next loop. + session.run(enqueue_ops) + else: + iterations = signal + for i in range(iterations): + logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + session.run(enqueue_ops) + count += 1 + + except Exception: # pylint: disable=broad-except + logging.error( + 'Failed running infeed, closing session.\n' + 'You may see an exception from your main session after this.', + exc_info=1 + ) + session.close() def join(self): logging.info('Waiting for Infeed Thread to exit.') @@ -294,17 +509,16 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): dequeue. """ - def __init__(self, run_config, mode, enqueue_fn, dequeue_ops=None): - self._tpu_job = _tpu_job(run_config, mode) - self._enqueue_fn = enqueue_fn + def __init__(self, ctx, enqueue_ops, dequeue_ops=None): + self._master_job = ctx.master_job + self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops def begin(self): - self._enqueue_ops = self._enqueue_fn() - self._iterations_per_loop_var = _create_iterations_per_loop() - logging.info('TPU job name %s', self._tpu_job) - self._init_op = [tpu.initialize_system(job=self._tpu_job)] - self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)] + logging.info('TPU job name %s', self._master_job) + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_op = [tpu.initialize_system(job=self._master_job)] + self._finalize_op = [tpu.shutdown_system(job=self._master_job)] def after_create_session(self, session, coord): logging.info('Init TPU system') @@ -326,6 +540,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): iterations = run_context.session.run(self._iterations_per_loop_var) self._infeed_thd_controller.send_next_batch_signal(iterations) if self._dequeue_ops is not None: + # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. logging.info('Dequeue next batch of data from outfeed.') self._outfeed_thd_controller.send_next_batch_signal(iterations) @@ -387,7 +602,7 @@ class _TPUStopAtStepHook(session_run_hook.SessionRunHook): if self._global_step_tensor is None: raise RuntimeError('Global step should be created.') - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): global_step = session.run(self._global_step_tensor) @@ -422,360 +637,288 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook): self._num_steps = num_steps def begin(self): - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): self._iterations_per_loop_var.load(self._num_steps, session=session) -class _PerShardOutput(object): - """Wraps input_fn's outputs into per-shard outputs. - - Used so that the model_fn can distinguish between sharded input and unsharded - inputs (e.g., for export_savedmodel()). - """ - - def __init__(self, output): - self.output = output - - def as_list(self): - return self.output - +def generate_per_core_enqueue_ops_fn_for_host( + ctx, input_fn, inputs_structure_recorder): + """Generates infeed enqueue ops for per-core input_fn on a single host.""" + infeed_queue_holder = {'instance': None} + + def enqueue_ops_fn(): + """A fn returns enqueue_ops.""" + num_cores_per_host = ctx.num_of_cores_per_host + per_host_sharded_inputs = [] + for core_ordinal in range(num_cores_per_host): + with ops.name_scope('ordinal_%d' % (core_ordinal)): + inputs = input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None -class _InputsHolder(object): - """A inputs holder holds the `features` and `labels' for TPU system. + inputs_structure_recorder.validate_and_record_structure( + features, labels) + flattened_inputs = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + per_host_sharded_inputs.append(flattened_inputs) - Model inputs returned by the `input_fn` can have one of the following forms: + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0])) + infeed_queue_holder['instance'] = infeed_queue + infeed_queue.set_configuration_from_sharded_input_tensors( + per_host_sharded_inputs) + + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs, + tpu_ordinal_function=ctx.tpu_ordinal_function) + return per_host_enqueue_ops + return enqueue_ops_fn, (lambda: infeed_queue_holder['instance']) + + +class _InputPipeline(object): + """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. + + `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from + call site. To be precise, based on the configuration in `_TPUContext`, it + invokes `input_fn` for all cores (usually multi-host TPU training) or for one + host (usually for single-host TPU evaluation), and sends all `features` and + `labels` returned by `input_fn` to TPU infeed. For per-core invocation, + `features` and `labels` are piped to infeed directly, one tuple for each + core. For per-host invocation, `features` and `labels` are split at host + (with respect to `batch_axis`) and piped to all cores accordingly. + + In addition, flatten/unflatten are handled by `_InputPipeline` also. Model + inputs returned by the `input_fn` can have one of the following forms: 1. features 2. (features, labels) Internally, form 1 is reformed to `(features, None)` as features and labels are passed separatedly to underlying methods. For TPU training, TPUEstimator - expects multiple `features` and `labels` tuples one for each shard. - - In addition, TPUEstimator allows various different structures for inputs - (namely `features` and `labels`). `features` can be `Tensor` or dict of - string name to `Tensor`, and `labels` could be `None`, `Tensor`, or dict of - string name to `Tensor`. TPU infeed/outfeed library expects flattened tensor - list. So, `features` and `labels` need to be flattened, before infeed enqueue, - and the structure of them needs to be recorded, in order to restore them after - infeed dequeue. - - `_InputsHolder` could hold the `features` and `labels` tuple for all shards - (usually multi-host TPU training) or for one host (usually for single-host TPU - evaluation), records the structure details (including presence, dict or single - tensor, dict names), validates the structure consistency cross all shards, and - encapsulates the flatten/unflatten logic. + may expect multiple `features` and `labels` tuples one for each core. + + TPUEstimator allows various different structures for inputs (namely `features` + and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`, + and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`. + TPU infeed/outfeed library expects flattened tensor list. So, `features` and + `labels` need to be flattened, before infeed enqueue, and the structure of + them needs to be recorded, in order to restore them after infeed dequeue. """ - def __init__(self, features=None, labels=None, num_shards=None): - """Constructor. - - Args: - features: features for one host or a list of features one for each shard - (must be type `_PerShardOutput`). Once provided, the corresponding - `labels` should be set also and this `_InputsHolder` is frozen to - prevent from future modification. If `None`, it is expected to add - features and labels for each shard by calling `append_tuple` later. - labels: labels for one host or a list of labels one for each shard - (must be type `_PerShardOutput`). - num_shards: Number of shards in the TPU system. Must be provided unless it - can be deduced from `features`. - - Raises: - ValueError: If both `sharded_features` and `num_shards` are `None`. - """ - # Holds the features and labels for all shards. - self._feature_list = [] - self._label_list = [] - - # Holds the structure of inputs - self._feature_names = [] - self._label_names = [] - self._has_labels = False - - # Internal state. - self._initialized = False - self._frozen = False - self._sharded = False - - if features is None: - if num_shards is None: - raise ValueError( - '`features` and `num_shards` cannot be both None') - self._num_shards = num_shards - elif isinstance(features, _PerShardOutput): - self._from_sharded_inputs(features, labels, num_shards) - else: - if num_shards is None: - raise ValueError( - '`num_shards` cannot be None for unsharded features.') - self._from_unsharded_inputs(features, labels, num_shards) - - def _from_unsharded_inputs(self, features, labels, num_shards): - """Initializes the inputs with unsharded features and labels.""" - self._num_shards = num_shards - if labels is not None: - self._has_labels = True - self.append_tuple((features, labels)) - else: - self.append_tuple(features) - - self._sharded = False - self._frozen = True - - def _from_sharded_inputs(self, sharded_features, sharded_labels, num_shards): - """Initializes the inputs with sharded features and labels.""" - if not isinstance(sharded_features, _PerShardOutput): - raise ValueError('`sharded_features` must have type `_PerShardOutput`.') - features = sharded_features.as_list() - - if num_shards is not None and num_shards != len(features): - raise ValueError( - '`num_shards` should be same as the length of sharded_features.') + class InputsStructureRecorder(object): + """The recorder to record inputs structure.""" + + def __init__(self): + # Holds the structure of inputs + self._feature_names = [] + self._label_names = [] + self._has_labels = False + + # Internal state. + self._initialized = False + + def has_labels(self): + return self._has_labels + + def validate_and_record_structure(self, features, labels): + """Validates and records the structure of features` and `labels`.""" + def _extract_key_names(tensor_or_dict): + if tensor_or_dict is None: + return [] + return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] + + # Extract structure. + has_labels = labels is not None + feature_names = _extract_key_names(features) + label_names = _extract_key_names(labels) + + if self._initialized: + # Verify the structure is same. The following should never happen. + assert feature_names == self._feature_names, 'feature keys mismatched' + assert label_names == self._label_names, 'label keys mismatched' + assert has_labels == self._has_labels, 'label presence mismatched' + else: + # Record structure. + self._initialized = True + self._feature_names = feature_names + self._label_names = label_names + self._has_labels = has_labels + + def flatten_features_and_labels(self, features, labels): + """Flattens the `features` and `labels` to a single tensor list.""" + flattened_inputs = [] + if self._feature_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([features[name] + for name in self._feature_names]) + else: + flattened_inputs.append(features) - self._num_shards = len(features) - if not self._num_shards: - raise ValueError('`sharded_features` should not be empty.') + if labels is not None: + if self._label_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([labels[name] for name in self._label_names]) + else: + flattened_inputs.append(labels) + return flattened_inputs + + def unflatten_features_and_labels(self, flattened_inputs): + """Restores the flattened inputs to original features and labels form. + + Args: + flattened_inputs: Flattened inputs for each shard. + + Returns: + A tuple of (`features`, `labels`), where `labels` could be None. + Each one, if present, should have identical structure (single tensor vs + dict) as the one returned by input_fn. + + Raises: + ValueError: If the number of expected tensors from `flattened_inputs` + mismatches the recorded structure. + """ + expected_num_features = (len(self._feature_names) if self._feature_names + else 1) + if self._has_labels: + expected_num_labels = (len(self._label_names) if self._label_names + else 1) + else: + expected_num_labels = 0 - if sharded_labels is not None: - if not isinstance(sharded_labels, _PerShardOutput): - raise ValueError('sharded_labels` must have type `_PerShardOutput`.') + expected_num_tensors = expected_num_features + expected_num_labels - self._has_labels = True - labels = sharded_labels.as_list() - if self._num_shards != len(labels): + if expected_num_tensors != len(flattened_inputs): raise ValueError( - 'Length of `sharded_features` and `sharded_labels` mismatch.') - - if self._has_labels: - for (f, l) in zip(features, labels): - self.append_tuple((f, l)) - else: - for f in features: - self.append_tuple(f) - - self._sharded = True - self._frozen = True - - def _extract_key_names(self, tensor_or_dict): - if tensor_or_dict is None: - return [] - - return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] - - def _validate(self, features, labels): - has_labels = labels is not None - feature_names = self._extract_key_names(features) - label_names = self._extract_key_names(labels) - - if self._initialized: - self._sharded = True - # The following should never happen. - assert feature_names == self._feature_names, 'feature keys mismatched' - assert label_names == self._label_names, 'label keys mismatched' - assert has_labels == self._has_labels, 'label presence mismatched' - else: - self._initialized = True - self._feature_names = feature_names - self._label_names = label_names - self._has_labels = has_labels - - @property - def sharded(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._sharded - - @property - def num_shards(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._num_shards - - def append_tuple(self, inputs): - """Appends `inputs` for one shard into holder. - - Args: - inputs: The return from `input_fn`, which could be features or tuple of - (features, labels). After the first `inputs` appended into - `_InputsHolder`, the structure of `features` and `labels is recorded. - Any future invocation should provide the `inputs` with same structure. - - Raises: - RuntimeError: If the internal data has been frozen already. - """ - if self._frozen: - raise RuntimeError('InputsHolder has frozen, which cannot be mutated.') - - # input_fn may return either features or (features, labels) - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None - - self._validate(features, labels) - - self._feature_list.append(features) - if labels is not None: - self._label_list.append(labels) - - def as_features_and_labels_tuple(self): - """Returns features and labels as grouped tuple. - - This is intended to be used to pass features and labels for all shards from - input_fn to model_fn as the parent class `Estimator` does not have the - concept of shards. So, grouped tuple is required. - - Once called, the internal data is frozen and `append_tuple` cannot be - invoked anymore. - - Returns: - A tuple of features and labels. Both have type `_PerShardOutput`, holding - the inputs for all shards. `labels` could be `None`. - - Raises: - RuntimeError: If the internal data has not been initialized. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - assert len(self._feature_list) == self._num_shards - if not self._label_list or all(l is None for l in self._label_list): - return _PerShardOutput(self._feature_list), None - - assert len(self._label_list) == self._num_shards - return (_PerShardOutput(self._feature_list), - _PerShardOutput(self._label_list)) - - def as_sharded_flattened_inputs(self): - """Flatten the features and label as tensor lists for all shards. - - Flattened tensor list contains all tensors in `features` (dict) and `labels` - (dict). Conceptually, it has the predicated structure like: - - ```python - flatten_list = [] - for name in features: - flatten_list.append(features[name]) - for name in labels: - flatten_list.append(labels[name]) - ``` - - This method handles the label is None case and single tensor case nicely. - - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. - - Returns: - A list of flattened inputs one for each shard. - - Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the inputs are sharded. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if not self._sharded: - raise ValueError('Inputs are not sharded.') - - sharded_inputs = [] - - for shard in range(self._num_shards): - flattened_inputs = self._as_flattened_inputs( - self._feature_list[shard], - self._label_list[shard] if self._has_labels else None) - sharded_inputs.append(flattened_inputs) - - return sharded_inputs - - def as_flattened_inputs(self): - """Flatten the features and label as a single tensor list for one host.""" - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if self._sharded: - raise ValueError('Inputs are sharded.') - - return self._as_flattened_inputs( - self._feature_list[0], - self._label_list[0] if self._has_labels else None) - - def _as_flattened_inputs(self, features, labels): - """Flattens the `features` and `labels` to a single tensor list.""" - flattened_inputs = [] - if self._feature_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([features[name] for name in self._feature_names]) - else: - flattened_inputs.append(features) - - if labels is not None: - if self._label_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([labels[name] for name in self._label_names]) + 'The number of flattened tensors mismatches expected num. ' + 'Expected {}, got {}'.format(expected_num_tensors, + len(flattened_inputs))) + if self._feature_names: + unflattened_features = dict( + zip(self._feature_names, flattened_inputs[:expected_num_features])) else: - flattened_inputs.append(labels) - return flattened_inputs + # Single tensor case + unflattened_features = flattened_inputs[0] + + if expected_num_labels == 0: + unflattened_label = None + elif self._label_names: + unflattened_label = dict(zip(self._label_names, + flattened_inputs[expected_num_features:])) + else: + # Single tensor case. + unflattened_label = flattened_inputs[expected_num_features] - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. + return unflattened_features, unflattened_label - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. + def __init__(self, input_fn, batch_axis, ctx): + """Constructor. Args: - flattened_inputs: Flattened inputs for one each, which should be created - by the `as_sharded_flattened_inputs` API. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. + input_fn: input fn for train or eval. + batch_axis: A python tuple of int values describing how each tensor + produced by the Estimator `input_fn` should be split across the TPU + compute shards. + ctx: A `_TPUContext` instance with mode. Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. + ValueError: If both `sharded_features` and `num_cores` are `None`. """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - expected_num_features = (len(self._feature_names) if self._feature_names - else 1) - if self._has_labels: - expected_num_labels = (len(self._label_names) if self._label_names - else 1) - else: - expected_num_labels = 0 + self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder() + + self._sharded_per_core = ctx.is_input_sharded_per_core() + self._input_fn = input_fn + self._infeed_queue = None + self._ctx = ctx + self._batch_axis = batch_axis + + def generate_infeed_enqueue_ops_and_dequeue_fn(self): + """Generates infeed enqueue ops and dequeue_fn.""" + # While tf.while_loop is called, the body function, which invokes + # `enqueue_fn` passed in, is called to construct the graph. So, input_fn + # structure is recorded. + enqueue_ops = self._invoke_input_fn_and_record_structure() + + def dequeue_fn(): + """dequeue_fn is used by TPU to retrieve the tensors.""" + values = self._infeed_queue.generate_dequeue_op() + # The unflatten process uses the structure information recorded above. + return self._inputs_structure_recorder.unflatten_features_and_labels( + values) + + return (enqueue_ops, dequeue_fn) + + def _invoke_input_fn_and_record_structure(self): + if self._sharded_per_core: + # Per-Core input pipeline deployment. + tpu_host_placement_fn = self._ctx.tpu_host_placement_function + enqueue_ops = [] + infeed_queues = [] + + # Invoke input pipeline for each core and placed on the corresponding + # host. + num_hosts = self._ctx.num_hosts + for host_id in range(num_hosts): + host_device = tpu_host_placement_fn(host_id=host_id) + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + enqueue_ops_fn, infeed_queue_getter = ( + generate_per_core_enqueue_ops_fn_for_host( + self._ctx, self._input_fn, self._inputs_structure_recorder)) + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + enqueue_ops.append(_wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + # Infeed_queue_getter must be called after enqueue_ops_fn is called. + infeed_queues.append(infeed_queue_getter()) + + # infeed_queue is used to generate dequeue ops. The only thing it uses for + # dequeue is dtypes and types. So, any one can be used. Here, grab the + # first one. + self._infeed_queue = infeed_queues[0] + return enqueue_ops - expected_num_tensors = expected_num_features + expected_num_labels - - if expected_num_tensors != len(flattened_inputs): - raise ValueError( - 'The number of flattened tensors mismatches expected num. ' - 'Expected {}, got {}'.format(expected_num_tensors, - len(flattened_inputs))) - if self._feature_names: - unflattened_features = dict(zip(self._feature_names, - flattened_inputs[:expected_num_features])) else: - # Single tensor case - unflattened_features = flattened_inputs[0] - - if expected_num_labels == 0: - unflattened_label = None - elif self._label_names: - unflattened_label = dict(zip(self._label_names, - flattened_inputs[expected_num_features:])) - else: - # Single tensor case. - unflattened_label = flattened_inputs[expected_num_features] - - return unflattened_features, unflattened_label + # TODO(b/67051042): Extend this to multi-host support. + host_id = 0 + host_device = self._ctx.tpu_host_placement_function(host_id=host_id) + def enqueue_fn(): + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + inputs = self._input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None + self._inputs_structure_recorder.validate_and_record_structure( + features, labels) + unsharded_tensor_list = ( + self._inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + + self._infeed_queue = tpu_feed.InfeedQueue( + tuple_types=[t.dtype for t in unsharded_tensor_list], + tuple_shapes=[t.shape for t in unsharded_tensor_list], + shard_dimensions=self._batch_axis) + self._infeed_queue.set_number_of_shards(self._ctx.num_cores) + + def placement_fn(core_id): + return self._ctx.tpu_host_placement_function(core_id=core_id) + return ( + self._infeed_queue.split_inputs_and_generate_enqueue_ops( + unsharded_tensor_list, + placement_function=placement_fn)) + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + return _wrap_computation_in_while_loop(device=host_device, + op_fn=enqueue_fn) + else: + return enqueue_fn() class _ModelFnWrapper(object): @@ -788,20 +931,17 @@ class _ModelFnWrapper(object): train and eval step. """ - def __init__(self, model_fn, config, params, mode, train_batch_size, - eval_batch_size): + def __init__(self, model_fn, config, params, ctx): self._model_fn = model_fn self._config = config self._params = params - self._mode = mode - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size + self._ctx = ctx def call_without_tpu(self, features, labels): # Let CrossShardOptimizer be called without TPU in model_fn, since it's # common to set the train_op even when running evaluate() or predict(). with tpu_function.tpu_shard_context(1): - return self._call_model_fn(features, labels, use_tpu=False) + return self._call_model_fn(features, labels) def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -831,7 +971,7 @@ class _ModelFnWrapper(object): features, labels = dequeue_fn() estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels, use_tpu=True)) + self._call_model_fn(features, labels)) loss, train_op = estimator_spec.loss, estimator_spec.train_op with ops.control_dependencies([train_op]): return array_ops.identity(loss) @@ -863,13 +1003,13 @@ class _ModelFnWrapper(object): A tuple of eval_fn and eval_metrics. The eval_fn representing the eval step for TPU. and eval_metrics is an `_EvalMetrics` instance. """ - eval_metrics = _EvalMetrics() + eval_metrics = _EvalMetrics(self._ctx) def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" features, labels = dequeue_fn() - tpu_estimator_spec = self._call_model_fn(features, labels, use_tpu=True) + tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): raise RuntimeError( 'estimator_spec used by TPU evaluation must have type' @@ -883,11 +1023,7 @@ class _ModelFnWrapper(object): return math_ops.add(total_loss, loss) return eval_step, eval_metrics - @property - def config(self): - return self._config - - def _call_model_fn(self, features, labels, use_tpu): + def _call_model_fn(self, features, labels): """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(self._model_fn) kwargs = {} @@ -898,12 +1034,11 @@ class _ModelFnWrapper(object): if 'labels' in model_fn_args: kwargs['labels'] = labels - else: - if labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') + elif labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') if 'mode' in model_fn_args: - kwargs['mode'] = self._mode + kwargs['mode'] = self._ctx.mode if 'config' in model_fn_args: kwargs['config'] = config if 'params' in model_fn_args: @@ -914,16 +1049,16 @@ class _ModelFnWrapper(object): 'model_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(self._model_fn)) - if self._mode == model_fn_lib.ModeKeys.TRAIN: - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._train_batch_size, config, use_tpu) - elif (self._mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._eval_batch_size, config, use_tpu) + + batch_size_for_model_fn = self._ctx.batch_size_for_model_fn + if batch_size_for_model_fn is not None: + params[_BATCH_SIZE_KEY] = batch_size_for_model_fn estimator_spec = self._model_fn(features=features, **kwargs) - if (not use_tpu) and isinstance(estimator_spec, TPUEstimatorSpec): + if (self._ctx.is_running_on_cpu() and + isinstance(estimator_spec, TPUEstimatorSpec)): + # The estimator_spec will be passed to `Estimator` directly, which expects + # type `EstimatorSpec`. return estimator_spec.as_estimator_spec() else: return estimator_spec @@ -946,7 +1081,8 @@ class _ModelFnWrapper(object): class _EvalMetrics(object): """Class wraps TPUEstimator.eval_metrics.""" - def __init__(self): + def __init__(self, ctx): + self._ctx = ctx self._metric_fn = None self._is_dict = False self._tensor_keys = [] @@ -970,8 +1106,6 @@ class _EvalMetrics(object): if isinstance(eval_metrics[1], (tuple, list)): fn_args = util.fn_args(eval_metrics[0]) - if 'self' in fn_args: - fn_args = tuple([arg for arg in fn_args if arg != 'self']) if len(eval_metrics[1]) != len(fn_args): raise RuntimeError( 'In TPUEstimatorSpec.eval_metrics, length of tensors does not ' @@ -1029,7 +1163,7 @@ class _EvalMetrics(object): raise RuntimeError('Eval metrics have not been recorded yet') return self._tensors - def to_metric_metric_ops_for_tpu(self, run_config, dummy_update_op): + def to_metric_metric_ops_for_tpu(self, dummy_update_op): """Creates the eval_metric_ops now based on the TPU outfeed. `eval_metric_ops` is defined in `EstimatorSpec`. From all shards, tensors @@ -1038,7 +1172,6 @@ class _EvalMetrics(object): metric fn. Args: - run_config: A `RunConfig` instance. dummy_update_op: A dummy update op. Returns: @@ -1050,9 +1183,7 @@ class _EvalMetrics(object): RuntimeError: If outfeed tensor is scalar. """ - num_shards = run_config.tpu_config.num_shards - job = _tpu_job(run_config, model_fn_lib.ModeKeys.EVAL) - job_device = '' if job is None else ('/job:%s' % job) + num_cores = self._ctx.num_cores # For each i, dequeue_ops[i] is a list containing the tensors from all # shards. This list is concatenated later. @@ -1061,8 +1192,9 @@ class _EvalMetrics(object): dequeue_ops.append([]) # Outfeed ops execute on each JF node. - for i in xrange(num_shards): - with ops.device('%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8)): + tpu_device_placement_fn = self._ctx.tpu_device_placement_function + for i in xrange(num_cores): + with ops.device(tpu_device_placement_fn(i)): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( dtypes=self._tensor_dtypes, shapes=self._tensor_shapes) for j, item in enumerate(outfeed_tensors): @@ -1070,7 +1202,7 @@ class _EvalMetrics(object): # It is assumed evaluation always happends on single host TPU system. So, # place all ops on tpu host if possible. - with ops.device('{}/device:CPU:0'.format(job_device)): + with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): for i, item in enumerate(dequeue_ops): if dequeue_ops[i][0].shape.ndims == 0: raise RuntimeError( @@ -1115,9 +1247,9 @@ class TPUEstimator(estimator_lib.Estimator): specify `train_batch_size` in constructor, and then get the batch size for each shard in `input_fn` and `model_fn` by `params['batch_size']`. If `TPUConfig.per_host_input_for_training` is `True`, `input_fn` is invoked per - host rather than per shard. In this case, a global batch size is transformed a + host rather than per core. In this case, a global batch size is transformed a per-host batch size in params for `input_fn`, but `model_fn` still gets - per-shard batch size. + per-core batch size. For evaluation, if `eval_batch_size` is None, it is executed on CPU, even if `use_tpu` is `True`. If `eval_batch_size` is not `None`, it is executed on @@ -1275,9 +1407,7 @@ class TPUEstimator(estimator_lib.Estimator): # We cannot store config and params in this constructor as parent # constructor might change them, such as assigning a temp dir for # config.model_dir. - model_function = _augment_model_fn(model_fn, train_batch_size, - eval_batch_size, use_tpu, - batch_axis) + model_function = self._augment_model_fn(model_fn, batch_axis) # Passing non-None params as wrapped model_fn has it. params = params or {} @@ -1286,12 +1416,13 @@ class TPUEstimator(estimator_lib.Estimator): model_dir=model_dir, config=config, params=params) - self._use_tpu = use_tpu - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) + # All properties passed to _TPUContext are immutable. + self._ctx = _TPUContext(self._config, train_batch_size, eval_batch_size, + use_tpu) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1307,10 +1438,10 @@ class TPUEstimator(estimator_lib.Estimator): return _create_global_step(graph) def _convert_train_steps_to_hooks(self, steps, max_steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.TRAIN, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_train_steps_to_hooks( + steps, max_steps) # On TPU. if steps is None and max_steps is None: @@ -1318,18 +1449,24 @@ class TPUEstimator(estimator_lib.Estimator): 'For TPU training, one of `steps` or `max_steps` must be set. ' 'Cannot be both `None`.') + # Estimator.train has explicit positiveness check. + if steps is not None: + util_lib.check_positive_integer(steps, 'Train steps') + if max_steps is not None: + util_lib.check_positive_integer(max_steps, 'Train max_steps') + return [_TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)] def _convert_eval_steps_to_hooks(self, steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.EVAL, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) if steps is None: raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') - if steps <= 0: - raise ValueError('Must specify steps > 0, given: {}'.format(steps)) + + util_lib.check_positive_integer(steps, 'Eval steps') hooks = [] hooks.append(evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access @@ -1364,197 +1501,115 @@ class TPUEstimator(estimator_lib.Estimator): if 'config' in input_fn_args: kwargs['config'] = config - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. - if mode == model_fn_lib.ModeKeys.TRAIN: - kwargs['params'][_BATCH_SIZE_KEY] = ( - _per_shard_batch_size(self._train_batch_size, config, self._use_tpu) - if not config.tpu_config.per_host_input_for_training else - self._train_batch_size) - elif (mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - # For TPU evaluation, input_fn is invoked for one host (instead of shard). - kwargs['params'][_BATCH_SIZE_KEY] = self._eval_batch_size - - if _is_running_on_cpu(self._use_tpu, mode, self._eval_batch_size): - with ops.device('/device:CPU:0'): - return input_fn(**kwargs) - - job = _tpu_job(config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) + with self._ctx.with_mode(mode) as ctx: + # Setting the batch size in params first. This helps user to have same + # input_fn for use_tpu=True/False. + batch_size_for_input_fn = ctx.batch_size_for_input_fn + if batch_size_for_input_fn is not None: + kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn - if mode == model_fn_lib.ModeKeys.TRAIN: - if not config.tpu_config.per_host_input_for_training: - # Now for TPU training. - num_shards = config.tpu_config.num_shards - inputs = _InputsHolder(num_shards=num_shards) - for i in range(config.tpu_config.num_shards): - with ops.device(placement_function(i)): - inputs.append_tuple(input_fn(**kwargs)) - return inputs.as_features_and_labels_tuple() - else: - # TODO(xiejw): Extend this to multi-host support. - with ops.device(placement_function(0)): + if ctx.is_running_on_cpu(): + with ops.device('/device:CPU:0'): return input_fn(**kwargs) - # Now for TPU evaluation. - with ops.device(placement_function(0)): - return input_fn(**kwargs) - - -# TODO(b/64607814): Ensure batch_axis works with nested structures. -def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config, - batch_axis, mode): - """Utility to convert input_fn to enqueue and dequeue fns for TPU. - - Args: - inputs_holder: An `_InputsHolder` holding features and labels. - run_config: A `RunConfig` instance. - batch_axis: A python list of batch dimensions. - mode: ModeKeys - - Returns: - A tuple of (dequeue_fn, enqueue_fn) - """ - if inputs_holder.sharded: - sharded_inputs = inputs_holder.as_sharded_flattened_inputs() - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(sharded_inputs[0])) - infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs) - else: - unsharded_inputs = inputs_holder.as_flattened_inputs() - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_inputs], - tuple_shapes=[t.shape for t in unsharded_inputs], - shard_dimensions=batch_axis) - infeed_queue.set_number_of_shards(inputs_holder.num_shards) - - def dequeue_fn(): - """dequeue_fn is used by the train_step in TPU to retrieve the tensors.""" - values = infeed_queue.generate_dequeue_op() - return inputs_holder.unflatten_features_and_labels(values) - - def tpu_ordinal_function(index): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - index: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - return index % 8 - - def enqueue_fn(): - """enqueue_fn is used to add ops to the graph to send tensors.""" - if inputs_holder.sharded: - return infeed_queue.generate_enqueue_ops( - sharded_inputs, tpu_ordinal_function=tpu_ordinal_function) - else: - job = _tpu_job(run_config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - # This assumes that if using more than 8 shards, - # the job configuration varies 'task'. - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) - return infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_inputs, placement_function=placement_function) - - return (dequeue_fn, enqueue_fn) - - -def _augment_model_fn(model_fn, train_batch_size, eval_batch_size, use_tpu, - batch_axis): - """Returns a new model_fn, which wraps the TPU support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, mode, - train_batch_size, eval_batch_size) - - # TODO(jhseu): Move to PREDICT to TPU. - if _is_running_on_cpu(use_tpu, mode, eval_batch_size): - logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu(features, labels) - - inputs = _InputsHolder(features=features, labels=labels, - num_shards=config.tpu_config.num_shards) - - dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn( - inputs, config, batch_axis, mode) - - if mode == model_fn_lib.ModeKeys.TRAIN: - loss = _train_on_tpu_system(model_fn_wrapper, dequeue_fn) - hooks = [ - TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn), - training.LoggingTensorHook( - {'loss': array_ops.identity(loss), - 'step': training.get_global_step()}, - every_n_secs=30) - ] - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops() - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_hooks=hooks, - train_op=control_flow_ops.group(*update_ops)) - - # Now eval. - total_loss, eval_metric_ops = _eval_on_tpu_system( - model_fn_wrapper, dequeue_fn) - iterations_per_loop_var = _create_iterations_per_loop() - mean_loss = math_ops.div( - total_loss, - math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) - - # Creates a dummy metric update_op for all metrics. Estimator expects all - # metrics in eval_metric_ops have update_op and calls them one by one. The - # real metric update_ops are invoked in a separated thread. So, here give - # Estimator the dummy op for all metrics. - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), reads - # all variables back from TPU and updates the eval step counter properly. - internal_ops_to_run = _sync_variables_ops() - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) - with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() - - eval_metric_ops, eval_update_ops = ( - eval_metric_ops.to_metric_metric_ops_for_tpu( - config, dummy_update_op)) - hooks = [ - TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn, eval_update_ops), - ] - - return model_fn_lib.EstimatorSpec( - mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops) - return _model_fn - - -def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn): + # For TPU computation, input_fn should be invoked in a tf.while_loop for + # performance. While constructing the tf.while_loop, the structure of + # inputs returned by the `input_fn` needs to be recorded. The structure + # includes whether features or labels is dict or single Tensor, dict keys, + # tensor shapes, and dtypes. The recorded structure is used to create the + # infeed dequeue ops, which must be wrapped and passed as a Fn, called + # inside the TPU computation, as the TPU computation is wrapped inside a + # tf.while_loop also. So, we either pass input_fn to model_fn or pass + # dequeue_fn to model_fn. Here, `input_fn` is passed directly as + # `features` in `model_fn` signature. + def _input_fn(): + return input_fn(**kwargs) + return _input_fn + + def _augment_model_fn(self, model_fn, batch_axis): + """Returns a new model_fn, which wraps the TPU support.""" + + def _model_fn(features, labels, mode, config, params): + """A Estimator `model_fn` for TPUEstimator.""" + with self._ctx.with_mode(mode) as ctx: + model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) + + # TODO(jhseu): Move to PREDICT to TPU. + if ctx.is_running_on_cpu(): + logging.info('Running %s on CPU', mode) + return model_fn_wrapper.call_without_tpu(features, labels) + + assert labels is None, '`labels` passed to `model_fn` must be `None`.' + # TPUEstimator._call_input_fn passes `input_fn` as features to here. + assert callable(features), '`input_fn` is not callable.' + input_fn = features + + input_holders = _InputPipeline(input_fn, batch_axis, ctx) + enqueue_ops, dequeue_fn = ( + input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) + + if mode == model_fn_lib.ModeKeys.TRAIN: + loss = _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn) + hooks = [ + TPUInfeedOutfeedSessionHook(ctx, enqueue_ops), + training.LoggingTensorHook( + {'loss': array_ops.identity(loss), + 'step': training.get_global_step()}, + every_n_secs=30) + ] + summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) + with ops.control_dependencies([loss]): + update_ops = _sync_variables_ops() + + # Validate the TPU training graph to catch basic errors + _validate_tpu_training_graph() + + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + training_hooks=hooks, + train_op=control_flow_ops.group(*update_ops)) + + # Now eval. + total_loss, eval_metric_ops = _eval_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) + iterations_per_loop_var = _create_or_get_iterations_per_loop() + mean_loss = math_ops.div( + total_loss, + math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) + + # Creates a dummy metric update_op for all metrics. Estimator expects + # all metrics in eval_metric_ops have update_op and calls them one by + # one. The real metric update_ops are invoked in a separated thread. So, + # here give Estimator the dummy op for all metrics. + with ops.control_dependencies([mean_loss]): + # After TPU evaluation computation is done (the mean_loss tensor), + # reads all variables back from TPU and updates the eval step counter + # properly + internal_ops_to_run = _sync_variables_ops() + internal_ops_to_run.append( + _increase_eval_step_op(iterations_per_loop_var)) + with ops.control_dependencies(internal_ops_to_run): + dummy_update_op = control_flow_ops.no_op() + + eval_metric_ops, eval_update_ops = ( + eval_metric_ops.to_metric_metric_ops_for_tpu(dummy_update_op)) + hooks = [ + TPUInfeedOutfeedSessionHook(ctx, enqueue_ops, eval_update_ops), + ] + + return model_fn_lib.EstimatorSpec( + mode, + loss=mean_loss, + evaluation_hooks=hooks, + eval_metric_ops=eval_metric_ops) + return _model_fn + + +def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - config = model_fn_wrapper.config.tpu_config - num_shards = config.num_shards - iterations_per_loop_var = _create_iterations_per_loop() + num_cores = ctx.num_cores + iterations_per_loop_var = _create_or_get_iterations_per_loop() single_tpu_eval_step, eval_metric_ops = ( model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) @@ -1567,15 +1622,15 @@ def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard, inputs=[], - num_shards=num_shards, + num_shards=num_cores, outputs_from_all_shards=False) return loss, eval_metric_ops -def _train_on_tpu_system(model_fn_wrapper, dequeue_fn): +def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - num_shards = model_fn_wrapper.config.tpu_config.num_shards - iterations_per_loop_var = _create_iterations_per_loop() + num_cores = ctx.num_cores + iterations_per_loop_var = _create_or_get_iterations_per_loop() single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step( dequeue_fn) @@ -1589,11 +1644,27 @@ def _train_on_tpu_system(model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard, inputs=[], - num_shards=num_shards, + num_shards=num_cores, outputs_from_all_shards=False) return loss +def _wrap_computation_in_while_loop(device, op_fn): + """Wraps the ops generated by `op_fn` in tf.while_loop.""" + def computation(i): + with ops.control_dependencies(op_fn()): + return i + 1 + + iterations_per_loop_var = _create_or_get_iterations_per_loop() + # By setting parallel_iterations=1, the parallel execution in while_loop is + # basically turned off. + with ops.device(device): + iterations = array_ops.identity(iterations_per_loop_var) + return control_flow_ops.while_loop( + lambda i: i < iterations, + computation, [constant_op.constant(0)], parallel_iterations=1) + + def _validate_tpu_training_graph(): """Validate graph before running distributed training. @@ -1609,3 +1680,5 @@ def _validate_tpu_training_graph(): if not cross_replica_sum_ops: raise ValueError( 'CrossShardOptimizer must be used for model training on TPUs.') + + diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py index d545a94ca6a2fdb3a9df2748b59300fd141dc55d..f8ba7d45e20b2f48e1409427665878df40a6db02 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -177,6 +177,10 @@ class ShardingPolicy(object): raise ValueError("shape %s does not contain shard_dimension %d" % (shape.as_list(), self._shard_dimension)) dims = shape.as_list() + if dims[self._shard_dimension] is None: + raise ValueError("shape %s must have a fixed size for dimension %d " + "that is known at graph construction time." % + (shape.as_list(), self._shard_dimension)) if (dims[self._shard_dimension] % self._number_of_shards) != 0: raise ValueError("shape %s cannot be sharded %d ways along dimension %d" % (shape.as_list(), self._number_of_shards, diff --git a/tensorflow/contrib/tpu/python/tpu/util.py b/tensorflow/contrib/tpu/python/tpu/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ea307d8900cf1b6d1e6e808d0b9ede26f86490 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/util.py @@ -0,0 +1,31 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== + +"""Utilities for the functionalities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + + +def check_positive_integer(value, name): + """Checks whether `value` is a positive integer.""" + if not isinstance(value, six.integer_types): + raise TypeError('{} must be int, got {}'.format(name, type(value))) + + if value <= 0: + raise ValueError('{} must be positive, got {}'.format(name, value)) diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 8e3d869a51c440e00059851f05f6ed2fe5558416..0df5ff50c0da4a8ccf344efb672db8dbc69b72da 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,6 +26,7 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", + "python/training/sgdr_learning_rate_decay.py", "python/training/training.py", "python/training/tuner.py", ], @@ -263,6 +264,7 @@ py_test( srcs = ["python/training/training_test.py"], shard_count = 3, srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":training_py", "//tensorflow/contrib/framework:framework_py", diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index 5523cc375fc20dc167fee0eaa6f1682dc1892c3f..95fbc50cba73b25b748c31ecd443eb19c0b6fc8a 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -47,7 +48,6 @@ _dtypes = input_py._dtypes _store_sparse_tensors = input_py._store_sparse_tensors _validate_keep_input = input_py._validate_keep_input _shapes = input_py._shapes -_smart_cond = input_py._smart_cond _which_queue = input_py._which_queue # pylint: enable=protected-access @@ -239,7 +239,7 @@ def bucket(tensors, ] return control_flow_ops.group(*enqueues, name="group_enqueues") - maybe_enqueue = _smart_cond( + maybe_enqueue = utils.smart_cond( keep_input, enqueue_which, control_flow_ops.no_op) diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 119fa3824bd77724471768980783e105d5595c4b..391899b34f90be25e10450ebf4e285ed2d39446f 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -25,6 +25,7 @@ import six from tensorflow.contrib.training.python.training import hparam_pb2 from tensorflow.python.framework import ops from tensorflow.python.util import compat +from tensorflow.python.util import deprecation # Define the regular expression for parsing a single clause of the input # (delimited by commas). A legal clause looks like: @@ -138,7 +139,7 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values, def parse_values(values, type_map): - """Parses hyperparameter values from a string into a python map.. + """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. For each pair, the value of the hyperparameter named `name` is set to @@ -470,24 +471,29 @@ class HParams(object): type_map[name] = param_type values_map = parse_values(values, type_map) - return self.set_from_map(values_map) + return self.override_from_dict(values_map) - def set_from_map(self, values_map): + def override_from_dict(self, values_dict): """Override hyperparameter values, parsing new values from a dictionary. Args: - values_map: Dictionary of name:value pairs. + values_dict: Dictionary of name:value pairs. Returns: The `HParams` instance. Raises: - ValueError: If `values_map` cannot be parsed. + ValueError: If `values_dict` cannot be parsed. """ - for name, value in values_map.items(): + for name, value in values_dict.items(): self.set_hparam(name, value) return self + @deprecation.deprecated(None, 'Use `override_from_dict`.') + def set_from_map(self, values_map): + """DEPRECATED. Use override_from_dict.""" + return self.override_from_dict(values_dict=values_map) + def set_model_structure(self, model_structure): self._model_structure = model_structure @@ -515,7 +521,7 @@ class HParams(object): ValueError: If `values_json` cannot be parsed. """ values_map = json.loads(values_json) - return self.set_from_map(values_map) + return self.override_from_dict(values_map) def values(self): """Return the hyperparameter values as a Python dictionary. @@ -526,6 +532,9 @@ class HParams(object): """ return {n: getattr(self, n) for n in self._hparam_types.keys()} + def __contains__(self, key): + return key in self._hparam_types + def __str__(self): return str(sorted(self.values().items())) diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index b01116a2139f76bab2e6219048c7c1aec013e626..f54514cefd39cab93e5c3a34786a6bb751b97704 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -32,6 +32,11 @@ class HParamsTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'): hparams.parse('xyz=123') + def testContains(self): + hparams = hparam.HParams(foo=1) + self.assertTrue('foo' in hparams) + self.assertFalse('bar' in hparams) + def testSomeValues(self): hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6') self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6'}, hparams.values()) @@ -93,11 +98,11 @@ class HParamsTest(test.TestCase): def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') - hparams.set_from_map({'a': -2, 'c': 'identity'}) + hparams.override_from_dict({'a': -2, 'c': 'identity'}) self.assertDictEqual({'a': -2, 'c': 'identity', 'b': 2.0}, hparams.values()) hparams = hparam.HParams(x=1, b=2.0, d=[0.5]) - hparams.set_from_map({'d': [0.1, 0.2, 0.3]}) + hparams.override_from_dict({'d': [0.1, 0.2, 0.3]}) self.assertDictEqual({'d': [0.1, 0.2, 0.3], 'x': 1, 'b': 2.0}, hparams.values()) diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index a1fbea57dd1202c1a22e6b3570e9378555fe3498..cff765d1e832e5a593462283444d7c4ed7831636 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -43,21 +43,21 @@ VerbsService::Stub::Stub( const std::shared_ptr< ::grpc::ChannelInterface>& channel) : channel_(channel), rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0], - ::grpc::internal::RpcMethod::NORMAL_RPC, + ::grpc::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status VerbsService::Stub::GetRemoteAddress( ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, GetRemoteAddressResponse* response) { - return ::grpc::internal::BlockingUnaryCall( + return ::grpc::BlockingUnaryCall( channel_.get(), rpcmethod_GetRemoteAddress_, context, request, response); } VerbsService::AsyncService::AsyncService() { for (int i = 0; i < 1; ++i) { - AddMethod(new ::grpc::internal::RpcServiceMethod( + AddMethod(new ::grpc::RpcServiceMethod( grpcVerbsService_method_names[i], - ::grpc::internal::RpcMethod::NORMAL_RPC, + ::grpc::RpcMethod::NORMAL_RPC, nullptr)); ::grpc::Service::MarkMethodAsync(i); } diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 86431ca030c38c56155801202714ee4a49b764df..6e2bf86dac2aa84ff453aaefbfc57cd3ee8bc1fd 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -28,6 +28,15 @@ limitations under the License. #include "tensorflow/contrib/verbs/verbs_service.pb.h" namespace grpc { + +// ensure internal namespace exists +namespace internal { +// bring in contents of external namespace +using namespace ::grpc; +} // namespace internal +// bring in contents of internal namespace +using namespace internal; + class CompletionQueue; class Channel; class RpcService; @@ -61,7 +70,7 @@ class VerbsService GRPC_FINAL { private: std::shared_ptr< ::grpc::ChannelInterface> channel_; - const ::grpc::internal::RpcMethod rpcmethod_GetRemoteAddress_; + const ::grpc::RpcMethod rpcmethod_GetRemoteAddress_; }; static std::unique_ptr NewStub( const std::shared_ptr< ::grpc::ChannelInterface>& channel, diff --git a/tensorflow/contrib/xla_tf_graph/BUILD b/tensorflow/contrib/xla_tf_graph/BUILD deleted file mode 100644 index 4a3a2de9b5e58cfab2e6f8de5c6789f1cbcebde7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -# Description: -# contains parts of TensorFlow that are experimental or unstable and which are not supported. - -package( - default_visibility = ["//visibility:public"], -) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - -cc_library( - name = "xla_tf_graph_util", - srcs = [ - "xla_tf_graph_util.cc", - ], - hdrs = [ - "xla_tf_graph_util.h", - ], - deps = [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "xla_tf_graph_util_test", - srcs = ["xla_tf_graph_util_test.cc"], - linkstatic = 1, - tags = ["nomac"], # b/63908145 - deps = [ - ":xla_tf_graph_util", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:scope", - "//tensorflow/compiler/jit:xla_cpu_jit", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework_internal", - "//tensorflow/core:ops", - "//tensorflow/core:tensorflow", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/kernels:cwise_op", - ], -) diff --git a/tensorflow/contrib/xla_tf_graph/README.md b/tensorflow/contrib/xla_tf_graph/README.md deleted file mode 100644 index a374189e813107bcf3fe71032d4baf16b3d164a2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# Xla Tf Graph - -## Description - -This module contains utilities to treat xla representation as tf graph to support mobile SOC experiments and leverage tf tools. - -Maintainers: -- Satoshi Kataoka (satok@google.com, github.com/satok16) diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc deleted file mode 100644 index 302aa6457ab08a30bca9c28a5f162331111c4b77..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc +++ /dev/null @@ -1,247 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h" - -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace tensorflow { -namespace xla_tf_graph { - -namespace { - -constexpr const char* const GRAPH_NAME = "xla_tf_graph"; -constexpr const char* const NODE_NAME_PREFIX = "xla"; - -Status ConvertPrimitiveTypeToDataType(const xla::PrimitiveType p_type, - DataType* d_type) { - switch (p_type) { - case xla::PRED: - *d_type = DT_BOOL; - return Status::OK(); - case xla::S8: - *d_type = DT_INT8; - return Status::OK(); - case xla::S16: - *d_type = DT_INT16; - return Status::OK(); - case xla::S32: - *d_type = DT_INT32; - return Status::OK(); - case xla::S64: - *d_type = DT_INT64; - return Status::OK(); - case xla::U8: - *d_type = DT_UINT8; - return Status::OK(); - case xla::U16: - *d_type = DT_UINT16; - return Status::OK(); - case xla::F16: - *d_type = DT_HALF; - return Status::OK(); - case xla::F32: - *d_type = DT_FLOAT; - return Status::OK(); - case xla::F64: - *d_type = DT_DOUBLE; - return Status::OK(); - default: - return errors::InvalidArgument( - "Unsupported PrimitiveType in ConvertPrimitiveTypeToDataType ", - xla::PrimitiveType_Name(p_type)); - } -} - -Status ConvertXlaShapeToTensorShapeType(const xla::Shape& xla_shape, - std::vector* tensor_shapes, - std::vector* data_types) { - switch (xla_shape.element_type()) { - case xla::TUPLE: { - for (const xla::Shape& element_shape : xla_shape.tuple_shapes()) { - if (element_shape.element_type() == xla::TUPLE) { - return errors::InvalidArgument("Nested tuple is not allowed."); - } - TF_RETURN_IF_ERROR(ConvertXlaShapeToTensorShapeType( - element_shape, tensor_shapes, data_types)); - } - return Status::OK(); - } - case xla::PRED: - case xla::S8: - case xla::S16: - case xla::S32: - case xla::S64: - case xla::U8: - case xla::U16: - case xla::U32: - case xla::U64: - case xla::F16: - case xla::F32: - case xla::F64: { - TensorShape shape; - DataType type; - TF_RETURN_IF_ERROR( - ConvertPrimitiveTypeToDataType(xla_shape.element_type(), &type)); - for (const int64& dim : xla_shape.dimensions()) { - shape.AddDim(dim); - } - tensor_shapes->emplace_back(shape); - data_types->emplace_back(type); - return Status::OK(); - } - default: - return errors::InvalidArgument( - "Unsupported PrimitiveType in ConvertXlaShapeToTensorShapeType ", - xla::PrimitiveType_Name(xla_shape.element_type())); - } -} - -string BuildXlaNodeName(const xla::OperationRequest& operation_request, - const string& xla_op_type, const string& suffix) { - const string name = strings::StrCat( - NODE_NAME_PREFIX, "/", operation_request.output_handle().handle(), "/", - xla_op_type); - if (suffix.empty()) { - return name; - } else { - return strings::StrCat(name, "/", suffix); - } -} - -string BuildXlaNodeName(const xla::OperationRequest& operation_request, - const string& xla_op_type) { - return BuildXlaNodeName(operation_request, xla_op_type, ""); -} - -string BuildXlaNodeOp(const protobuf::Message& msg, const string& suffix) { - return strings::StrCat(msg.GetDescriptor()->name(), "/", suffix); -} - -string BuildXlaNodeOp(const protobuf::Message& msg) { - return BuildXlaNodeOp(msg, ""); -} - -Status ConvertOpRequestToXlaNode(const xla::OperationRequest& operation_request, - XlaNode* xla_node) { - const xla::OpRequest& op_request = operation_request.request(); - switch (op_request.op_case()) { - case xla::OpRequest::kBinaryOpRequest: { - const xla::BinaryOpRequest& op = op_request.binary_op_request(); - xla_node->op_type = - BuildXlaNodeOp(op, xla::BinaryOperation_Name(op.binop())); - xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); - xla_node->input_ids.emplace_back(std::make_tuple(op.lhs().handle(), 0)); - xla_node->input_ids.emplace_back(std::make_tuple(op.rhs().handle(), 0)); - for (const int64& dim : op.broadcast_dimensions()) { - xla_node->broadcast_dimensions.emplace_back(dim); - } - break; - } - case xla::OpRequest::kParameterRequest: { - const xla::ParameterRequest& op = op_request.parameter_request(); - xla_node->op_type = BuildXlaNodeOp(op, ""); - xla_node->name = - BuildXlaNodeName(operation_request, xla_node->op_type, op.name()); - break; - } - case xla::OpRequest::kVariadicOpRequest: { - const xla::VariadicOpRequest& op = op_request.variadic_op_request(); - xla_node->op_type = - BuildXlaNodeOp(op, xla::VariadicOperation_Name(op.varop())); - xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); - for (const xla::ComputationDataHandle& handle : op.operands()) { - xla_node->input_ids.emplace_back(std::make_tuple(handle.handle(), 0)); - } - break; - } - case xla::OpRequest::kGetTupleElementRequest: { - const xla::GetTupleElementRequest& op = - op_request.get_tuple_element_request(); - xla_node->op_type = BuildXlaNodeOp(op); - xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); - xla_node->input_ids.emplace_back( - std::make_tuple(op.operand().handle(), op.index())); - break; - } - default: - // TODO(satok): Implement all possible cases. - LOG(FATAL) << "Op request: " << op_request.op_case() - << " is not supported yet."; - break; - } - - CHECK(!xla_node->name.empty()); - CHECK(!xla_node->op_type.empty()); - - TF_RETURN_IF_ERROR(ConvertXlaShapeToTensorShapeType( - operation_request.output_shape(), &xla_node->output_shapes, - &xla_node->output_data_types)); - return Status::OK(); -} - -void SetupXlaCpuClient(std::unique_ptr* flib_def, - std::unique_ptr* compiler) { - xla::Client* client = xla::ClientLibrary::LocalClientOrDie(); - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def->reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - - // Setup compiler options - XlaCompiler::Options options; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - options.device_type = &device_type; - options.flib_def = flib_def->get(); - options.client = client; - compiler->reset(new XlaCompiler(options)); -} - -} // namespace - -xla::StatusOr> -ConvertTfGraphToXlaSessionModule(const std::vector& args, - std::unique_ptr graph) { - CHECK(graph); - - std::unique_ptr flib_def; - std::unique_ptr compiler; - - SetupXlaCpuClient(&flib_def, &compiler); - - // Compile graph and build computation - XlaCompiler::CompilationResult result; - TF_CHECK_OK(compiler->CompileGraph(XlaCompiler::CompileOptions(), GRAPH_NAME, - std::move(graph), args, &result)); - - return result.computation->Snapshot(); -} - -xla::StatusOr> -ConvertXlaSessionModuleToXlaNodes(const xla::SessionModule& session_module) { - std::unordered_map xla_nodes; - for (const auto& operation_request : session_module.entry().requests()) { - XlaNode xla_node; - TF_RETURN_IF_ERROR( - ConvertOpRequestToXlaNode(operation_request.second, &xla_node)); - xla_nodes.emplace(operation_request.first, xla_node); - } - return std::move(xla_nodes); -} - -} // namespace xla_tf_graph -} // namespace tensorflow diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h deleted file mode 100644 index e635290851f7e5d078d98d845e7488fc3cd94049..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ -#define TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ - -#include - -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace xla_tf_graph { - -// A set of utilities to handle xla computation requests. -// These utilities help developers leverage existing tools to work with -// xla computations, also provide a way to support TensorFlow ops by -// implementing xla computations so that they can do experiments on their -// specialized environments. - -// A structure to represent typed attributes of TensorFlow graph node. -// This structure contains op specific attributes as members so that -// we can treat them explicitly. -struct XlaNode { - // Unique node name - string name; - // Op type of xla computation - string op_type; - // List of pair of unique id and port of input node. - // We store this value instead - // of node name in order not to wait for all XlaNodes to be constructed. - std::vector> input_ids; - // Oputput shapes - std::vector output_shapes; - // Output data types - std::vector output_data_types; - - //--------------------------- - // Op specific attributes - // #xla::OpRequest::kBinaryOpRequest - std::vector broadcast_dimensions; -}; - -// Convert a tf graph to a xla session module -xla::StatusOr> -ConvertTfGraphToXlaSessionModule(const std::vector& args, - std::unique_ptr graph); - -// Convert a xla session module to a map to XlaNode from unique id -xla::StatusOr> -ConvertXlaSessionModuleToXlaNodes(const xla::SessionModule& session_module); - -} // namespace xla_tf_graph -} // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc deleted file mode 100644 index 144269303ee140bb7a9a30133a5d88b41b4f4273..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/function_ops.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace xla_tf_graph { - -static std::unique_ptr BuildAddGraph() { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); - auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); - // See tf2xla/kernels/binary_ops.cc - auto c = ops::Add(scope.WithOpName("C"), a, b); - auto d = ops::_Retval(scope.WithOpName("D"), c, 0); - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -static std::vector BuildAddGraphArguments() { - // Builds a description of the arguments. - std::vector args(2); - args[0].kind = XlaCompiler::Argument::kParameter; - args[0].type = DT_INT32; - // Difference of dimension will add extra broadcast_dimensions. - // broadcast_dimension generates an additional HloInstruction - // in user_computation.cc - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2, 2}); - args[1].kind = XlaCompiler::Argument::kParameter; - args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); - return args; -} - -// CAVEAT: Debug purpose only. -// This function dumps a protobuf string format of HloModule. -static void DumpHloGraphForDebug(const std::vector& args, - std::unique_ptr graph) { - std::unique_ptr flib_def; - std::unique_ptr flr; - std::unique_ptr compiler; - - xla::Client* client = xla::ClientLibrary::LocalClientOrDie(); - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - - // Compiles the graph. - XlaCompiler::Options options; - DeviceType device_type("XLA_CPU_JIT"); - options.device_type = &device_type; - options.client = client; - options.flib_def = flib_def.get(); - compiler.reset(new XlaCompiler(options)); - - // Compile graph - XlaCompiler::CompilationResult result; - TF_CHECK_OK(compiler->CompileGraph(XlaCompiler::CompileOptions(), "dump", - std::move(graph), args, &result)); - - // Convert to hlo - xla::Computation& computation = *result.computation; - - xla::Service* service( - static_cast(xla::ClientLibrary::GetXlaService( - static_cast(client)->platform()))); - const xla::ComputationTracker& computation_tracker = - service->computation_tracker(); - - auto user_computation_status = - computation_tracker.Resolve(computation.handle()); - TF_CHECK_OK(user_computation_status.status()); - auto user_computation = user_computation_status.ConsumeValueOrDie(); - xla::VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - std::unique_ptr hlo_module = - std::move(computation_tracker - .BuildHloModule(versioned_handle, xla::HloModuleConfig()) - .ValueOrDie()); - VLOG(1) << "--- DUMP HLO ---"; - VLOG(1) << hlo_module->ToString(); -} - -TEST(XlaTfGraphUtil, ConvertTfGraphToSessionModule) { - // Builds a description of the arguments. - std::vector args = BuildAddGraphArguments(); - std::unique_ptr graph = BuildAddGraph(); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr session_module, - ConvertTfGraphToXlaSessionModule(args, std::move(graph))); - - ASSERT_EQ(4, session_module->entry().requests_size()); - - VLOG(1) << "--- DUMP ---"; - VLOG(1) << session_module->DebugString(); - DumpHloGraphForDebug(args, BuildAddGraph()); -} - -TEST(XlaTfGraphUtil, ConvertXlaSessionModuleToXlaNodes) { - std::vector args = BuildAddGraphArguments(); - std::unique_ptr graph = BuildAddGraph(); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr session_module, - ConvertTfGraphToXlaSessionModule(args, std::move(graph))); - TF_ASSERT_OK_AND_ASSIGN(auto xla_nodes, - ConvertXlaSessionModuleToXlaNodes(*session_module)); - EXPECT_EQ(session_module->entry().requests_size(), xla_nodes.size()); -} - -} // namespace xla_tf_graph -} // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index aaede2a6bb223dcdb2f70231c9cebe20e68e6b64..1c58aa3315bb88eeb69035c11f56ddfd3d651eee 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [ "framework/function.proto", "framework/graph.proto", "framework/graph_transfer_info.proto", + "framework/iterator.proto", "framework/kernel_def.proto", "framework/log_memory.proto", "framework/node_def.proto", @@ -510,6 +511,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":lib", + ":lib_internal", ":op_gen_overrides_proto_cc", ":protos_all_cc", ], @@ -651,14 +653,15 @@ cc_library( ":image_ops_op_lib", ":io_ops_op_lib", ":linalg_ops_op_lib", - ":lookup_ops_op_lib", ":logging_ops_op_lib", + ":lookup_ops_op_lib", ":math_ops_op_lib", ":nn_ops_op_lib", ":no_op_op_lib", ":parsing_ops_op_lib", ":random_ops_op_lib", ":remote_fused_graph_ops_op_lib", + ":resource_variable_ops_op_lib", ":script_ops_op_lib", ":sdca_ops_op_lib", ":sendrecv_ops_op_lib", @@ -780,6 +783,7 @@ cc_library( "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", @@ -888,6 +892,7 @@ cc_library( ":test", "//tensorflow/cc:scope", "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:ops_util", ], ) @@ -1396,6 +1401,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "platform/platform.h", "platform/protobuf_internal.h", "platform/setround.h", + "platform/snappy.h", "platform/tensor_coding.h", "platform/tracing.h", ] @@ -1406,7 +1412,7 @@ cc_library( hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), defines = tf_additional_lib_defines() + [ - "SNAPPY", + "TF_USE_SNAPPY", ] + tf_additional_verbs_lib_defines() + tf_additional_mpi_lib_defines() + tf_additional_gdr_lib_defines(), @@ -1938,6 +1944,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ tf_cuda_library( name = "core_cpu_impl", srcs = [ + "common_runtime/accumulate_n_optimizer.cc", "common_runtime/allocator_retry.cc", "common_runtime/bfc_allocator.cc", "common_runtime/build_graph_options.cc", @@ -2121,6 +2128,7 @@ GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", "common_runtime/gpu/gpu_init.h", + "common_runtime/gpu/gpu_managed_allocator.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", "common_runtime/gpu/pool_allocator.h", @@ -2135,6 +2143,7 @@ tf_cuda_library( "common_runtime/gpu/gpu_debug_allocator.cc", "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", + "common_runtime/gpu/gpu_managed_allocator.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", "common_runtime/gpu/gpu_util_platform_specific.cc", @@ -2171,6 +2180,7 @@ tf_cuda_library( ":lib", ":lib_internal", ":protos_all_cc", + ":stream_executor", "//third_party/eigen3", ] + if_static([":gpu_runtime_impl"]), ) @@ -2253,7 +2263,6 @@ cc_library( "lib/io/block_builder.h", "lib/io/format.h", "lib/random/philox_random_test_utils.h", - "platform/snappy.h", ], deps = [ ":lib", @@ -2500,6 +2509,7 @@ tf_cc_test( srcs = ["framework/op_gen_lib_test.cc"], deps = [ ":op_gen_lib", + ":protos_all_cc", ":test", ":test_main", ], @@ -2666,6 +2676,22 @@ tf_cc_tests( ], ) +tf_cc_test_mkl( + name = "mkl_runtime_tests", + size = "small", + srcs = ["common_runtime/mkl_cpu_allocator_test.cc"], + linkstatic = 1, + deps = [ + ":core", + ":core_cpu", + ":framework", + ":framework_internal", + ":test", + ":test_main", + ":testlib", + ], +) + tf_cc_test_mkl( name = "mkl_related_tests", size = "small", @@ -2693,7 +2719,20 @@ tf_cc_test_mkl( "//tensorflow/cc:sendrecv_ops", "//tensorflow/core/kernels:ops_util", "//third_party/eigen3", - ], + ] + if_mkl([ + "//tensorflow/core/kernels:mkl_aggregate_ops", + "//tensorflow/core/kernels:mkl_concat_op", + "//tensorflow/core/kernels:mkl_conv_op", + "//tensorflow/core/kernels:mkl_cwise_ops_common", + "//tensorflow/core/kernels:mkl_fused_batch_norm_op", + "//tensorflow/core/kernels:mkl_identity_op", + "//tensorflow/core/kernels:mkl_input_conversion_op", + "//tensorflow/core/kernels:mkl_lrn_op", + "//tensorflow/core/kernels:mkl_pooling_ops", + "//tensorflow/core/kernels:mkl_relu_op", + "//tensorflow/core/kernels:mkl_reshape_op", + "//tensorflow/core/kernels:mkl_tfconv_op", + ]), ) tf_cc_tests_gpu( @@ -3315,6 +3354,41 @@ tf_cc_test( ], ) +filegroup( + name = "base_api_def", + data = glob(["api_def/base_api/*"]), +) + +filegroup( + name = "python_api_def", + data = glob(["api_def/python_api/*"]), +) + +tf_cc_test( + name = "api_test", + srcs = ["api_def/api_test.cc"], + data = [ + ":base_api_def", + "//tensorflow/cc:ops/op_gen_overrides.pbtxt", + ], + tags = [ + "manual", + "notap", + ], + deps = [ + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":lib_test_internal", + ":op_gen_lib", + ":op_gen_overrides_proto_cc", + ":ops", + ":protos_all_cc", + ":test", + ], +) + tf_cc_test_gpu( name = "gpu_tracer_test", size = "small", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ceeb172fa0a9abf2ab7adcfc801b4bcb5fa04381 --- /dev/null +++ b/tensorflow/core/api_def/api_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Test that verifies tensorflow/core/api_def/base_api/api_def*.pbtxt files +// are correct. If api_def*.pbtxt do not match expected contents, run +// tensorflow/core/api_def/base_api/update_api_def.sh script to update them. + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { +constexpr char kDefaultApiDefDir[] = + "tensorflow/core/api_def/base_api"; +constexpr char kOverridesFilePath[] = + "tensorflow/cc/ops/op_gen_overrides.pbtxt"; +constexpr char kApiDefFileFormat[] = "api_def_%c.pbtxt"; +constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + +// Get map from first character to ApiDefs for ops +// that start with that character. +std::unordered_map GenerateApiDef( + const OpList& ops, const OpGenOverrides& overrides) { + std::unordered_map name_to_override; + for (const auto& op_override : overrides.op()) { + name_to_override[op_override.name()] = op_override; + } + + std::unordered_map api_defs_map; + + for (const auto& op : ops.op()) { + CHECK(!op.name().empty()) + << "Encountered empty op name: %s" << op.DebugString(); + const char file_id = toupper(op.name()[0]); + CHECK(isalpha(file_id)) << "Unexpected op name: " << op.name(); + ApiDef* api_def = api_defs_map[file_id].add_op(); + api_def->set_graph_op_name(op.name()); + + if (name_to_override.find(op.name()) != name_to_override.end()) { + const auto& op_override = name_to_override[op.name()]; + // Set visibility + if (op_override.skip()) { + api_def->set_visibility(ApiDef_Visibility_SKIP); + } else if (op_override.hide()) { + api_def->set_visibility(ApiDef_Visibility_HIDDEN); + } + // Add endpoints + if (!op_override.rename_to().empty()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(op_override.rename_to()); + } else { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(op.name()); + } + for (auto& alias : op_override.alias()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(alias); + } + // Add attributes + for (auto& attr : op.attr()) { + auto* api_def_attr = api_def->add_attr(); + api_def_attr->set_name(attr.name()); + for (auto& attr_override : op_override.attr_default()) { + if (attr.name() == attr_override.name()) { + *(api_def_attr->mutable_default_value()) = attr_override.value(); + } + } + for (auto& attr_rename : op_override.attr_rename()) { + if (attr.name() == attr_rename.from()) { + api_def_attr->set_rename_to(attr_rename.to()); + } + } + } + } else { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(op.name()); + } + // Add docs + api_def->set_summary(op.summary()); + api_def->set_description(op.description()); + } + return api_defs_map; +} + +// Reads golden api defs file with the given suffix. +string GetGoldenApiDefsStr(Env* env, const string& api_files_dir, char suffix) { + string file_path = strings::Printf( + io::JoinPath(api_files_dir, kApiDefFileFormat).c_str(), suffix); + if (env->FileExists(file_path).ok()) { + string file_contents; + TF_EXPECT_OK(ReadFileToString(env, file_path, &file_contents)); + return file_contents; + } + return ""; +} + +void RunApiTest(bool update_api_def, const string& api_files_dir) { + // Read C++ overrides file + string overrides_file_contents; + Env* env = Env::Default(); + TF_EXPECT_OK( + ReadFileToString(env, kOverridesFilePath, &overrides_file_contents)); + + // Read all ops + OpList ops; + OpRegistry::Global()->Export(false, &ops); + const std::vector multi_line_fields = {"description"}; + + // Get expected ApiDefs + OpGenOverrides overrides; + auto new_api_defs_map = GenerateApiDef(ops, overrides); + + bool updated_at_least_one_file = false; + + for (char c : kAlphabet) { + string golden_api_defs_str = GetGoldenApiDefsStr(env, api_files_dir, c); + string new_api_defs_str = new_api_defs_map[c].DebugString(); + new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); + if (golden_api_defs_str == new_api_defs_str) { + continue; + } + if (update_api_def) { + string output_file_path = + io::JoinPath(api_files_dir, strings::Printf(kApiDefFileFormat, c)); + if (new_api_defs_str.empty()) { + std::cout << "Deleting " << output_file_path << "..." << std::endl; + TF_EXPECT_OK(env->DeleteFile(output_file_path)); + } else { + std::cout << "Updating " << output_file_path << "..." << std::endl; + TF_EXPECT_OK( + WriteStringToFile(env, output_file_path, new_api_defs_str)); + } + updated_at_least_one_file = true; + } else { + EXPECT_EQ(golden_api_defs_str, new_api_defs_str) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; + } + } + + if (update_api_def && !updated_at_least_one_file) { + std::cout << "Api def files are already up to date." << std::endl; + } +} + +TEST(ApiTest, GenerateBaseAPIDef) { RunApiTest(false, kDefaultApiDefDir); } +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + bool update_api_def = false; + tensorflow::string api_files_dir = tensorflow::kDefaultApiDefDir; + std::vector flag_list = { + tensorflow::Flag( + "update_api_def", &update_api_def, + "Whether to update tensorflow/core/api_def/base_api/api_def*.pbtxt " + "files if they differ from expected API."), + tensorflow::Flag("api_def_dir", &api_files_dir, + "Base directory of api_def*.pbtxt files.")}; + std::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parsed_values_ok) { + std::cerr << usage << std::endl; + return 2; + } + if (update_api_def) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::RunApiTest(update_api_def, api_files_dir); + return 0; + } + testing::InitGoogleTest(&argc, argv); + // Run tests + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/core/api_def/base_api/api_def_A.pbtxt b/tensorflow/core/api_def/base_api/api_def_A.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8193d1bc624535c7894430284686e8664fb71a2d --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_A.pbtxt @@ -0,0 +1,670 @@ +op { + graph_op_name: "Abort" + endpoint { + name: "Abort" + } + summary: "Raise a exception to abort the process when called." + description: <= 2." +} +op { + graph_op_name: "AdjustContrastv2" + endpoint { + name: "AdjustContrastv2" + } + summary: "Adjust the contrast of one or more images." + description: < [2.0132, 1.056] +``` + +@compatibility(numpy) +Equivalent to np.angle. +@end_compatibility +END +} +op { + graph_op_name: "Any" + endpoint { + name: "Any" + } + summary: "Computes the \"logical or\" of elements across dimensions of a tensor." + description: < l1 else 0.0 +accum = accum_new +END +} +op { + graph_op_name: "ApplyFtrlV2" + endpoint { + name: "ApplyFtrlV2" + } + summary: "Update \'*var\' according to the Ftrl-proximal scheme." + description: < l1 else 0.0 +accum = accum_new +END +} +op { + graph_op_name: "ApplyGradientDescent" + endpoint { + name: "ApplyGradientDescent" + } + summary: "Update \'*var\' by subtracting \'alpha\' * \'delta\' from it." +} +op { + graph_op_name: "ApplyMomentum" + endpoint { + name: "ApplyMomentum" + } + summary: "Update \'*var\' according to the momentum scheme. Set use_nesterov = True if you" + description: < threshold`) +or and `false` otherwise. + +This operation is useful for Locality-Sensitive-Hashing (LSH) and other +algorithms that use hashing approximations of cosine and `L2` distances; +codes can be generated from an input via: + +```python +codebook_size = 50 +codebook_bits = codebook_size * 32 +codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], + dtype=x.dtype, + initializer=tf.orthogonal_initializer()) +codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) +codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 +# now codes has shape x.shape[:-1] + [codebook_size] +``` + +**NOTE**: Currently, the innermost dimension of the tensor must be divisible +by 8. + +Given an `input` shaped `[s0, s1, ..., s_n]`, the output is +a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. +END +} +op { + graph_op_name: "Complex" + endpoint { + name: "Complex" + } + summary: "Converts two real numbers to a complex number." + description: < [[2.25 + 4.75j], [3.25 + 5.75j]] +``` +END +} +op { + graph_op_name: "ComplexAbs" + endpoint { + name: "ComplexAbs" + } + summary: "Computes the complex absolute value of a tensor." + description: < [0, 0, 0], [0, 2, 0], [0, 5, 0] +``` + +This is typically used by gradient computations for a concat operation. +END +} +op { + graph_op_name: "ConcatV2" + endpoint { + name: "ConcatV2" + } + summary: "Concatenates tensors along one dimension." +} +op { + graph_op_name: "ConcatenateDataset" + endpoint { + name: "ConcatenateDataset" + } + summary: "Creates a dataset that concatenates `input_dataset` with `another_dataset`." +} +op { + graph_op_name: "ConditionalAccumulator" + endpoint { + name: "ConditionalAccumulator" + } + summary: "A conditional accumulator for aggregating gradients." + description: < [-2.25 - 4.75j, 3.25 - 5.75j] +``` +END +} +op { + graph_op_name: "Const" + endpoint { + name: "Const" + } + summary: "Returns a constant tensor." +} +op { + graph_op_name: "ControlTrigger" + endpoint { + name: "ControlTrigger" + } + summary: "Does nothing. Serves as a control trigger for scheduling." + description: < [a, a * b, a * b * c] +``` + +By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +performed instead: + +```python +tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +``` + +By setting the `reverse` kwarg to `True`, the cumprod is performed in the +opposite direction: + +```python +tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +``` + +This is more efficient than using separate `tf.reverse` ops. + +The `reverse` and `exclusive` kwargs can also be combined: + +```python +tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +``` +END +} +op { + graph_op_name: "Cumsum" + endpoint { + name: "Cumsum" + } + summary: "Compute the cumulative sum of the tensor `x` along `axis`." + description: < [a, a + b, a + b + c] +``` + +By setting the `exclusive` kwarg to `True`, an exclusive cumsum is +performed instead: + +```python +tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] +``` + +By setting the `reverse` kwarg to `True`, the cumsum is performed in the +opposite direction: + +```python +tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] +``` + +This is more efficient than using separate `tf.reverse` ops. + +The `reverse` and `exclusive` kwargs can also be combined: + +```python +tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +``` +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_D.pbtxt b/tensorflow/core/api_def/base_api/api_def_D.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ff8a7223c7223f5c4b72ffc7154b7fc77d8eeb06 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_D.pbtxt @@ -0,0 +1,790 @@ +op { + graph_op_name: "DebugGradientIdentity" + endpoint { + name: "DebugGradientIdentity" + } + summary: "Identity op for gradient debugging." + description: <