diff --git a/.gitignore b/.gitignore index 09734fe4974935956fd599f7f86cd5c4d195d5e2..9572a3e97c4ebf61211cbfe3af594d20606eda72 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ cmake_build/ .idea/** /build/ /tensorflow/core/util/version_info.cc +/tensorflow/python/framework/fast_tensor_util.cpp diff --git a/README.md b/README.md index 6339c57c95032d490c8acf02f21967d4ee35ddb9..24bbb6cec10e16c7b6ae37b7cf8b6f90ebe5e5dd 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,11 @@ People who are a little more adventurous can also try our nightly binaries: **Nightly pip packages** * We are pleased to announce that TensorFlow now offers nightly pip packages -under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) project on pypi. -Simply run `pip install tf-nightly` in a clean environment to install the nightly -tensorflow build. We currently only support CPU packages on Linux, Mac, and Windows. -GPU packages on all platforms will arrive soon! +under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and +[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on pypi. +Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean +environment to install the nightly TensorFlow build. We support CPU and GPU +packages on Linux, Mac, and Windows. **Individual whl files** diff --git a/RELEASE.md b/RELEASE.md index 3d497dbaa965d2cf239cab8360109bf5804b6f6e..d30ee69f40e24672f1d68f81109e5d9bd266e81d 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,6 +1,50 @@ # Release 1.4.0 ## Major Features And Improvements +* `tf.keras` is now part of the core TensorFlow API. +* `tf.data` is now part of the core TensorFlow API. + * The API is now subject to backwards compatibility guarantees. + * For a guide to migrating from the `tf.contrib.data` API, see the + [README](https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/contrib/data/README.md). + * Major new features include `Dataset.from_generator()` (for building an input + pipeline from a Python generator), and the `Dataset.apply()` method for + applying custom transformation functions. + * Several custom transformation functions have been added, including + `tf.contrib.data.batch_and_drop_remainder()` and + `tf.contrib.data.sloppy_interleave()`. +* Add `train_and_evaluate` for simple distributed `Estimator` training. +* Add `tf.spectral.dct` for computing the DCT-II. +* Add Mel-Frequency Cepstral Coefficient support to `tf.contrib.signal` + (with GPU and gradient support). +* Add a self-check on `import tensorflow` for Windows DLL issues. +* Add NCHW support to `tf.depth_to_space` on GPU. +* SinhArcsinh (scalar) distribution added to `contrib.distributions`. +* Make `GANEstimator` opensource. +* `Estimator.export_savedmodel()` now includes all valid serving signatures + that can be constructed from the Serving Input Receiver and all available + ExportOutputs. For instance, a classifier may provide regression- and + prediction-flavored outputs, in addition to the classification-flavored one. + Building signatures from these allows TF Serving to honor requests using the + different APIs (Classify, Regress, and Predict). Furthermore, + `serving_input_receiver_fn()` may now specify alternative subsets of nodes + that may act as inputs. This allows, for instance, producing a prediction + signature for a classifier that accepts raw `Tensors` instead of a serialized + `tf.Example`. +* Add `tf.contrib.bayesflow.hmc`. +* Add `tf.contrib.distributions.MixtureSameFamily`. +* Make `Dataset.shuffle()` always reshuffles after each iteration by default. +* Add `tf.contrib.bayesflow.metropolis_hastings`. +* Add `log_rate` parameter to `tf.contrib.distributions.Poisson`. +* Extend `tf.contrib.distributions.bijector` API to handle some non-injective + transforms. +* Java: + * Generics (e.g., `Tensor`) for improved type-safety + (courtesy @andrewcmyers). + * Support for multi-dimensional string tensors. + * Support loading of custom operations (e.g. many in `tf.contrib`) on Linux + and OS X +* All our prebuilt binaries have been built with CUDA 8 and cuDNN 6. + We anticipate releasing TensorFlow 1.5 with CUDA 9 and cuDNN 7. ## Bug Fixes and Other Changes * `tf.nn.rnn_cell.DropoutWrapper` is now more careful about dropping out LSTM @@ -12,6 +56,57 @@ * Removed `tf.contrib.training.python_input`. The same behavior, in a more flexible and reproducible package, is available via the new `tf.contrib.data.Dataset.from_generator` method! +* Fix `tf.contrib.distributions.Affine` incorrectly computing log-det-jacobian. +* Fix `tf.random_gamma` incorrectly handling non-batch, scalar draws. +* Resolved a race condition in TensorForest TreePredictionsV4Op. +* Google Cloud Storage file system and Hadoop file system support are now + default build options. +* Custom op libraries must link against libtensorflow_framework.so + (installed at `tf.sysconfig.get_lib()`). + +## Breaking Changes to the API +* The signature of the `tf.contrib.data.rejection_resample()` function has been + changed. It now returns a function that can be used as an argument to + `Dataset.apply()`. +* Remove `tf.contrib.data.Iterator.from_dataset()` method. Use + `Dataset.make_initializable_iterator()` instead. +* Remove seldom used and unnecessary `tf.contrib.data.Iterator.dispose_op()`. +* Reorder some TFGAN loss functions in a non-backwards compatible way. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Abdullah Alrasheed, abenmao, Adam Salvail, Aditya Dhulipala, Ag Ramesh, +Akimasa Kimura, Alan Du, Alan Yee, Alexander, Amit Kushwaha, Amy, Andrei Costinescu, +Andrei Nigmatulin, Andrew Erlichson, Andrew Myers, Andrew Stepanov, Androbin, AngryPowman, +Anish Shah, Anton Daitche, Artsiom Chapialiou, asdf2014, Aseem Raj Baranwal, Ash Hall, +Bart Kiers, Batchu Venkat Vishal, ben, Ben Barsdell, Bill Piel, Carl Thomé, Catalin Voss, +Changming Sun, Chengzhi Chen, Chi Zeng, Chris Antaki, Chris Donahue, Chris Oelmueller, +Chris Tava, Clayne Robison, Codrut, Courtial Florian, Dalmo Cirne, Dan J, Darren Garvey, +David Kristoffersson, David Norman, David RöThlisberger, DavidNorman, Dhruv, DimanNe, +Dorokhov, Duncan Mac-Vicar P, EdwardDixon, EMCP, error.d, FAIJUL, Fan Xia, +Francois Xavier, Fred Reiss, Freedom" Koan-Sin Tan, Fritz Obermeyer, Gao, Xiang, +Guenther Schmuelling, Guo Yejun (郭叶军), Hans Gaiser, HectorSVC, Hyungsuk Yoon, +James Pruegsanusak, Jay Young, Jean Wanka, Jeff Carpenter, Jeremy Rutman, Jeroen BéDorf, +Jett Jones, Jimmy Jia, jinghuangintel, jinze1994, JKurland, Joel Hestness, joetoth, +John B Nelson, John Impallomeni, John Lawson, Jonas, Jonathan Dekhtiar, joshkyh, Jun Luan, +Jun Mei, Kai Sasaki, Karl Lessard, karl@kubx.ca, Kb Sriram, Kenichi Ueno, Kevin Slagle, +Kongsea, Lakshay Garg, lhlmgr, Lin Min, liu.guangcong, Loki Der Quaeler, Louie Helm, +lucasmoura, Luke Iwanski, Lyndon White, Mahmoud Abuzaina, Marcel Puyat, Mark Aaron Shirley, +Michele Colombo, MtDersvan, Namrata-Ibm, Nathan Luehr, Naurril, Nayana Thorat, Nicolas Lopez, +Niranjan Hasabnis, Nolan Liu, Nouce, Oliver Hennigh, osdamv, Patrik Erdes, +Patryk Chrabaszcz, Pavel Christof, Penghao Cen, postBG, Qingqing Cao, Qingying Chen, qjivy, +Raphael, Rasmi, raymondxyang, Renze Yu, resec, Roffel, Ruben Vereecken, Ryohei Kuroki, +sandipmgiri, Santiago Castro, Scott Kirkland, Sean Vig, Sebastian Raschka, Sebastian Weiss, +Sergey Kolesnikov, Sergii Khomenko, Shahid, Shivam Kotwalia, Stuart Berg, Sumit Gouthaman, +superzerg, Sven Mayer, tetris, Ti Zhou, Tiago Freitas Pereira, Tian Jin, Tomoaki Oiki, +Vaibhav Sood, vfdev, Vivek Rane, Vladimir Moskva, wangqr, Weber Xie, Will Frey, +Yan Facai (颜发才), yanivbl6, Yaroslav Bulatov, Yixing Lao, Yong Tang, youkaichao, +Yuan (Terry) Tang, Yue Zhang, Yuxin Wu, Ziming Dong, ZxYuan, 黄璞 + +We are also grateful to all who filed issues or helped resolve them, asked and +answered questions, and were part of inspiring discussions. # Release 1.3.0 diff --git a/WORKSPACE b/WORKSPACE index 32d3d94ec232a5bf0eb0092b5d04df8440127408..1bf1069f8801c9d135d77c871520ff733b7713e9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "25f5399f18d8bf9ce435f85c6bbf671ec4820bc4396b3022cc5dc4bc66303609", - strip_prefix = "rules_closure-0.4.2", + sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257", + strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", # 2017-08-29 - "https://github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", + "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28 ], ) diff --git a/configure.py b/configure.py index df2c74d23d8ea306028c8c0406c5475d31fa884f..95835e538b62371d671aa7adb0f2f12b71639a58 100644 --- a/configure.py +++ b/configure.py @@ -30,7 +30,8 @@ try: except ImportError: from distutils.spawn import find_executable as which -_TF_BAZELRC = '.tf_configure.bazelrc' +_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), + '.tf_configure.bazelrc') _DEFAULT_CUDA_VERSION = '8.0' _DEFAULT_CUDNN_VERSION = '6' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' @@ -975,6 +976,7 @@ def main(): run_gen_git_source(environ_cp) if is_windows(): + environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -987,9 +989,11 @@ def main(): set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', - 'with_gcp_support', False, 'gcp') + 'with_gcp_support', True, 'gcp') set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', - 'with_hdfs_support', False, 'hdfs') + 'with_hdfs_support', True, 'hdfs') + set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', + 'with_s3_support', True, 's3') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 272931d458c15a91437647f9a4c50f73a50a9a9d..fa5da5fdbb4442ea1b971623ea1447ddd2e8f4d6 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -120,6 +120,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "ios_x86_64", + values = { + "cc_target_os": "apple", + "cpu": "ios_x86_64", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "linux_x86_64", values = {"cpu": "k8"}, @@ -185,6 +194,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_s3_support", + values = {"define": "with_s3_support=true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_xla_support", values = {"define": "with_xla_support=true"}, @@ -276,8 +291,8 @@ config_setting( package_group( name = "internal", packages = [ - "//learning/protonn/llgtm/...", "//tensorflow/...", + "//tensorflow_fold/llgtm/...", ], ) @@ -400,6 +415,7 @@ filegroup( "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/quantize:all_files", "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", "//tensorflow/contrib/remote_fused_graph/pylib:all_files", @@ -454,6 +470,7 @@ filegroup( "//tensorflow/core/kernels/fuzzing:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/kernels/neon:all_files", + "//tensorflow/core/lib/db:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", @@ -491,7 +508,9 @@ filegroup( "//tensorflow/python/keras:all_files", "//tensorflow/python/kernel_tests:all_files", "//tensorflow/python/kernel_tests/distributions:all_files", + "//tensorflow/python/kernel_tests/linalg:all_files", "//tensorflow/python/ops/distributions:all_files", + "//tensorflow/python/ops/linalg:all_files", "//tensorflow/python/profiler:all_files", "//tensorflow/python/profiler/internal:all_files", "//tensorflow/python/saved_model:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 6919dfe71124ebe861c1649b123e8a714056e45d..ef7eb5a4d16b29aecc34f33cb41dd7cf9450c5f2 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -30,7 +30,10 @@ tf_cuda_library( name = "c_api_internal", srcs = ["c_api.h"], hdrs = ["c_api_internal.h"], - visibility = ["//tensorflow/c:__subpackages__"], + visibility = [ + "//tensorflow:internal", + "//tensorflow/c:__subpackages__", + ], deps = select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 334f867e47800507760eaa71dce91186f646f72d..79fbd8c90c8ce938824925ac931e579fc4c823d7 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1854,18 +1854,18 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, return; } const int last_node_id = graph->graph.num_node_ids(); - std::vector> return_outputs_vec; - status->status = tensorflow::ImportGraphDef( - opts->opts, def, &graph->graph, &graph->refiner, &return_outputs_vec); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); if (!status->status.ok()) return; for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { auto* node = graph->graph.FindNodeId(i); if (node != nullptr) graph->name_map[node->name()] = node; } - DCHECK_EQ(return_outputs_vec.size(), num_return_outputs); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); for (int i = 0; i < num_return_outputs; ++i) { - return_outputs[i].oper = ToOperation(return_outputs_vec[i].first); - return_outputs[i].index = return_outputs_vec[i].second; + return_outputs[i].oper = ToOperation(results.return_tensors[i].first); + return_outputs[i].index = results.return_tensors[i].second; } } @@ -1945,11 +1945,11 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, } // TOOD(skyewm): change to OutputTensor - std::vector> return_tensors; + tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( - ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors)); + ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); - for (const auto& pair : return_tensors) { + for (const auto& pair : results.return_tensors) { return_nodes->emplace_back(pair.first, pair.second); } return Status::OK(); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index db94828e1a8adb5af7c34500c0675b1a6e93b805..76cfcd5e0d44e92126ac99075842ebdb8d5bc145 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -118,6 +118,8 @@ typedef enum TF_DataType { TF_HALF = 19, TF_RESOURCE = 20, TF_VARIANT = 21, + TF_UINT32 = 22, + TF_UINT64 = 23, } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding @@ -1144,7 +1146,7 @@ TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( const void* proto, size_t proto_len, TF_Status* status); // Sets function attribute named `attr_name` to value stored in `proto`. -// If this attribute is already set to another value, it is overriden. +// If this attribute is already set to another value, it is overridden. // `proto` should point to a sequence of bytes of length `proto_len` // representing a binary serialization of an AttrValue protocol // buffer. diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index e7b9bca5b50e4837534c315b8fa2ca161019d100..b1f7bdaa5420a56386e6983052df20aa976aa867 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/checkpoint_reader.h" #include +#include #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -24,43 +25,43 @@ limitations under the License. #include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; CheckpointReader::CheckpointReader(const string& filename, TF_Status* out_status) - : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) { + : reader_(nullptr), + v2_reader_(nullptr), + var_to_shape_map_(nullptr), + var_to_data_type_map_(nullptr) { // Depending on whether this is a V2 ckpt, initializes "reader_" or // "v2_reader_". std::vector v2_path; if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() && !v2_path.empty()) { - v2_reader_ = - new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */); + v2_reader_.reset( + new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); if (!v2_reader_->status().ok()) { Set_TF_Status_from_Status(out_status, v2_reader_->status()); return; } - var_to_shape_map_ptr_ = BuildV2VarToShapeMap(); + auto result = BuildV2VarMaps(); + var_to_shape_map_.swap(result.first); + var_to_data_type_map_.swap(result.second); } else { - reader_ = new TensorSliceReader(filename); + reader_.reset(new TensorSliceReader(filename)); if (!reader_->status().ok()) { Set_TF_Status_from_Status(out_status, reader_->status()); return; } - var_to_shape_map_ptr_ = - new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()); + var_to_shape_map_.reset( + new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap())); + var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap( + reader_->GetVariableToDataTypeMap())); } } -CheckpointReader::~CheckpointReader() { - delete var_to_shape_map_ptr_; - delete reader_; - delete v2_reader_; -} - bool CheckpointReader::HasTensor(const string& name) const { if (reader_ != nullptr) { return reader_->HasTensor(name, nullptr, nullptr); @@ -70,8 +71,14 @@ bool CheckpointReader::HasTensor(const string& name) const { const TensorSliceReader::VarToShapeMap& CheckpointReader::GetVariableToShapeMap() const { - CHECK(var_to_shape_map_ptr_); - return *var_to_shape_map_ptr_; + CHECK(var_to_shape_map_); + return *var_to_shape_map_; +} + +const TensorSliceReader::VarToDataTypeMap& +CheckpointReader::GetVariableToDataTypeMap() const { + CHECK(var_to_data_type_map_); + return *var_to_data_type_map_; } const string CheckpointReader::DebugString() const { @@ -100,7 +107,9 @@ void CheckpointReader::GetTensor( } } -TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { +std::pair, + std::unique_ptr> +CheckpointReader::BuildV2VarMaps() { CHECK(v2_reader_ != nullptr); CHECK(v2_reader_->status().ok()); @@ -123,18 +132,23 @@ TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { } // Second pass: adds the entries, ignoring the filtered keys. - TensorSliceReader::VarToShapeMap* var_to_shape_map = - new TensorSliceReader::VarToShapeMap; + std::unique_ptr var_to_shape_map( + new TensorSliceReader::VarToShapeMap); + std::unique_ptr var_to_data_type_map( + new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - (*var_to_shape_map)[v2_reader_->key().ToString()] = - TensorShape(entry.shape()); + string key = v2_reader_->key().ToString(); + (*var_to_shape_map)[key] = TensorShape(entry.shape()); + (*var_to_data_type_map)[key] = DataType(entry.dtype()); } - return var_to_shape_map; // Owned by caller. + // The returned pointers are owned by the caller. + return std::make_pair(std::move(var_to_shape_map), + std::move(var_to_data_type_map)); } } // namespace checkpoint diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 1124416380df624f97b3ce2ebaadb04b3c17d341..4de1300a7f66a8b4eb8074819432fd7dd597bb15 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_C_CHECKPOINT_READER_H #define TENSORFLOW_C_CHECKPOINT_READER_H +#include +#include + #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" @@ -24,7 +27,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_slice_reader.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; @@ -38,15 +40,18 @@ class TensorSliceReader; class CheckpointReader { public: CheckpointReader(const string& filepattern, TF_Status* out_status); - ~CheckpointReader(); bool HasTensor(const string& name) const; const string DebugString() const; - // Returns a map from variable names to its shape. Slices of a partitioned + // Returns a map from variable names to their shapes. Slices of a partitioned // tensor are combined into a single entry. const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const; + // Returns a map from variable names to their data types. Slices of a + // partitioned tensor are combined into a single entry. + const TensorSliceReader::VarToDataTypeMap& GetVariableToDataTypeMap() const; + // Attempts to look up the tensor named "name" and stores the found result in // "out_tensor". void GetTensor(const string& name, @@ -54,14 +59,19 @@ class CheckpointReader { TF_Status* out_status) const; private: - // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller. + // Uses "v2_reader_" to build "var name -> shape" and "var name -> data type" + // maps; both owned by caller. // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()". - TensorSliceReader::VarToShapeMap* BuildV2VarToShapeMap(); + std::pair, + std::unique_ptr > + BuildV2VarMaps(); + + // Invariant: exactly one of "reader_" and "v2_reader_" is non-null. + std::unique_ptr reader_; + std::unique_ptr v2_reader_; - // Invariant: exactly one of "reader_" and "v2_reader_" is non-nullptr. - TensorSliceReader* reader_; // Owned. - BundleReader* v2_reader_; // Owned. - TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned. + std::unique_ptr var_to_shape_map_; + std::unique_ptr var_to_data_type_map_; TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader); }; diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 52945d32391ddcd9bddb7726ddac68ee1ba9ae58..96f3c3e195e7025252c1e3cda5436237ad89257b 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -10,13 +10,15 @@ load( tf_cuda_library( name = "c_api", - srcs = ["c_api.cc"], + srcs = [ + "c_api.cc", + "c_api_internal.h", + ], hdrs = ["c_api.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ @@ -33,6 +35,21 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "c_api_internal", + hdrs = ["c_api_internal.h"], + deps = [ + ":c_api", + ":runtime", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_internal", + ], +) + tf_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], @@ -53,7 +70,6 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ @@ -85,3 +101,14 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "tape", + srcs = ["tape.cc"], + hdrs = ["tape.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 801d7307494e6585fbb7ee0fa4e6724ebe2c6f94..514a4010bc81bb280c3a1208b57a5db752f52f8a 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -51,64 +52,6 @@ string DeviceName(tensorflow::Device* d) { } } // namespace -struct TFE_Context { - explicit TFE_Context(TF_Session* s) : session(s) {} - - // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* session; - tensorflow::Rendezvous* rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - std::unique_ptr pflr; - - std::unordered_map - kernel_cache; - - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { - return pflr->GetFLR(d->name()); - } - - const std::vector& devices() { return session->devices; } -}; - -struct TFE_TensorHandle { - TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) - : t(t), d(d) {} - - tensorflow::Tensor t; - // TODO(ashankar): d == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('d' should always be a - // valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* d; -}; - -struct TFE_Op { - TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} - - bool const is_function() const { return attr_types == nullptr; } - - TFE_Context* ctx; // Must outlive the TFE_Op. - const string name; - tensorflow::AttrBuilder attrs; - const tensorflow::AttrTypeMap* attr_types; - std::vector inputs; - std::vector input_devices; - tensorflow::Device* device; -}; - extern "C" { TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { @@ -330,6 +273,20 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return ret; } +TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, + const char* op_or_function_name, + const char* attr_name, unsigned char* is_list, + TF_Status* status) { + TF_AttrType ret; + TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); + if (!status->status.ok()) { + return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. + } + ret = TFE_OpGetAttrType(op, attr_name, is_list, status); + TFE_DeleteOp(op); + return ret; +} + void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { op->attrs.Set(attr_name, value); } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a4f7d308fbb4008d00bd97abf40c9ead5fdb1986..9bfa63711b5360b33819434f9a551030e0f988c8 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -107,6 +107,12 @@ TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_St TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status); +// Get an attribute type given an op name; a fusion of TFE_NewOp and +// TFE_OpGetAttrType for use from Python without the overhead of the individual +// calls and memory management of TFE_Op. +TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( + TFE_Context* ctx, const char* op_or_function_name, const char* attr_name, + unsigned char* is_list, TF_Status* status); TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..712526f17002a612a145f80538977fedfde00038 --- /dev/null +++ b/tensorflow/c/eager/c_api_internal.h @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/runtime.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +struct TFE_Context { + explicit TFE_Context(TF_Session* s) : session(s) {} + + // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. + TF_Session* session; + tensorflow::Rendezvous* rendezvous; + + tensorflow::mutex functions_mu; + tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ + tensorflow::OpRegistry::Global(), {}}; + + // One FunctionLibraryRuntime per device. + // func_libs[i] is the FunctionLibraryRuntime corresponding to + // session->devices[i]. + std::unique_ptr pflr; + + std::unordered_map + kernel_cache; + + tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { + return pflr->GetFLR(d->name()); + } + + const std::vector& devices() { return session->devices; } +}; + +struct TFE_TensorHandle { + TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) + : t(t), d(d) {} + + tensorflow::Tensor t; + // TODO(ashankar): d == nullptr iff local CPU + // This was expedient, but perhaps worth revisiting ('d' should always be a + // valid pointer?) + // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are + // provided with the appropriate TFE_Context. + // + // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a + // TFE_TensorHandle does not outlive the TFE_Context from which it came? + tensorflow::Device* d; +}; + +struct TFE_Op { + TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) + : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + + bool const is_function() const { return attr_types == nullptr; } + + TFE_Context* ctx; // Must outlive the TFE_Op. + const tensorflow::string name; + tensorflow::AttrBuilder attrs; + const tensorflow::AttrTypeMap* attr_types; + std::vector inputs; + std::vector input_devices; + tensorflow::Device* device; +}; + +#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc new file mode 100644 index 0000000000000000000000000000000000000000..464612a81ebda428f5582b6927f3a3b00a5aa6f5 --- /dev/null +++ b/tensorflow/c/eager/tape.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/tape.h" + +namespace tensorflow { +namespace eager { + +bool GradientTape::ShouldRecord(gtl::ArraySlice tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; + } + } + return false; +} + +void GradientTape::Watch(int64 tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +void GradientTape::RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, void* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id)) { + backward_function_deleter(); + return; + } + std::vector ids; + ids.reserve(input_tensor_id.size()); + for (int64 i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64 op_id = next_op_id_++; + std::vector tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. + tensor_tape_[o.id] = op_id; + tensor_usage_[o.id] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, + backward_function_deleter}; +} + +void GradientTape::DeleteTrace(int64 tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64 op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64 id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(); + op_tape_.erase(op_it); +} + +std::pair GradientTape::Export() { + return {std::move(tensor_tape_), std::move(op_tape_)}; +} + +} // namespace eager +} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h new file mode 100644 index 0000000000000000000000000000000000000000..df51f300eb61d54cb1e06d5a58a9b10e834f73c4 --- /dev/null +++ b/tensorflow/c/eager/tape.h @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TAPE_H_ +#define TENSORFLOW_C_EAGER_TAPE_H_ + +// Language-agnostic gradient tape. Does not perform backpropagation, just +// maintains the data structures required to do so. + +#include +#include +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace eager { + +// Information about a tensor. +struct TapeTensor { + int64 id; // Expected to be unique in the lifetime of this process. + DataType dtype; + TensorShape shape; +}; + +// Represents an entry in the tape. +struct OpTapeEntry { + string op_type; + std::vector output_tensor_info; + std::vector input_tensor_id; + + // TODO(apassos) consider narrowing down this interface. + void* backward_function; + + // Should be called before deleting the backward function. TODO(apassos) use + // unique_ptrs to ensure this happens. + std::function backward_function_deleter; +}; + +// Map from tensor_id to internally-defined operation-id of the operation which +// produced this tensor. A value of -1 means that the tensor was directly +// watched and not the result of any operation in the tape. +using TensorTape = std::unordered_map; + +// Map from operation-id to tape entry. +using OpTape = std::unordered_map; + +// Traces the execution of operations, doing eager garbage collection, and +// exporting a full trace so other code can do backpropagation. Not thread-safe. +class GradientTape { + public: + GradientTape() {} + + bool ShouldRecord(gtl::ArraySlice tensor_ids); + + void Watch(int64 tensor_id); + + void RecordOperation(const string& op_type, + gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + void* backward_function, + const std::function& backward_function_deleter); + + void DeleteTrace(int64 tensor_id); + + // Note: it is only valid to call Export once per tape, and after calling + // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch, + // Record, and Delete have undefined behavior). + std::pair Export(); + + private: + TensorTape tensor_tape_; + OpTape op_tape_; + int64 next_op_id_{0}; + + // Map from tensor id to number of remaining usages (i.e. how many entries in + // the tape refer to it); to aid in tape garbage collection. + std::unordered_map tensor_usage_; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TAPE_H_ diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc index 2423d83dda93938aa1a2ba0ed0ed7356bd65d39f..d2d887f32c44af5980b50785f282187d0f6fcff4 100644 --- a/tensorflow/c/while_loop_test.cc +++ b/tensorflow/c/while_loop_test.cc @@ -318,7 +318,7 @@ TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) { @@ -358,7 +358,7 @@ TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } // TODO(skyewm): enable this when it works (currently segfaults!) @@ -389,7 +389,7 @@ TEST_F(CApiWhileLoopTest, WrongGraph) { params_->body_outputs[0] = inputs_[0]; // TODO(skyewm): improve error message ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, BadTypes) { diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 3682ebd9432f0107d3b93465cdc9afc900fab029..80112f9b44b1d5fd65a7d47788b072dc47a2b29a 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -45,6 +45,7 @@ tf_cc_test( srcs = ["framework/gradients_test.cc"], deps = [ ":cc_ops", + ":client_session", ":grad_op_registry", ":grad_ops", ":gradients", diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 0ec5b9a1bdae61c10bfacb96978314912c76b7c5..affd90b1bcc7cb4a8b3ffed6aeeb4bd480f5e314 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -91,6 +91,13 @@ class SymbolicGradientBuilder { // `summed_grads` is the sum of `exit_node`s gradients. Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads); + // Gets the set of node ids at which to stop backprop. These are all elements + // of `outputs_` that do not get transitively consumed by other `outputs_`. + // Used to identify nodes at which to stop backprop. + std::unordered_set GetStopBackpropNodes( + const std::vector& reachable_nodes, + std::unordered_set output_nodes); + const Scope& scope_; const ops::GradOpRegistry* registry_; const std::vector& outputs_; @@ -117,10 +124,6 @@ class SymbolicGradientBuilder { // gradients from `grad_inputs_`. std::deque ready_; - // The set of node ids in `outputs_`. Used to identify nodes at which to stop - // backprop. - std::unordered_set output_nodes_; - // The set of node ids in `inputs_`. Used to identify nodes at backprop // frontier. Maps from Output -> index into `grad_outputs_`. std::unordered_map input_nodes_; @@ -186,6 +189,63 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { return reachable_nodes; } +std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( + const std::vector& reachable_nodes, + std::unordered_set output_nodes) { + // Output nodes that get transitively consumed by other `outputs_` are stored + // in `internal_outputs`. + std::unordered_set internal_outputs; + std::unordered_set visited; + // Initialize `queue` for BFS traversal. Nodes in `queue` hold upcoming nodes + // along with the last Node in `output_` encountered along that path. If no + // `output_` node was encountered, pair.second will be nullptr. + std::deque> queue; + for (const Output& nout : inputs_) { + if (visited.find(nout.node()) == visited.end()) { + queue.push_back(std::make_pair(nout.node(), static_cast(nullptr))); + visited.insert(nout.node()); + } + } + // BFS from nodes in 'inputs_' along out edges for the entire graph. Internal + // output nodes are recorded during the traversal. All nodes that are output + // nodes but not internal output nodes are considered the frontier of the + // output nodes, and thus our stop backprop nodes. + while (!queue.empty()) { + std::pair p = queue.front(); + Node* n = p.first; + queue.pop_front(); + for (const Edge* e : n->out_edges()) { + // If a node is not reachable from outputs_, we can stop. + if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue; + if (visited.find(e->dst()) != visited.end()) continue; + + int node_id = e->dst()->id(); + Node* last_output_node = p.second; + if (output_nodes.find(node_id) != output_nodes.end()) { + // We reached an output node. + if (last_output_node != nullptr) { + // If we had already found an output node on this path so we mark + // it as an internal output. + internal_outputs.insert(last_output_node->id()); + } + // Mark this newly found output node to insert in the queue. + last_output_node = e->dst(); + } + queue.push_back(std::make_pair(e->dst(), last_output_node)); + visited.insert(e->dst()); + } + } + // Finally, we set stop_backprop_nodes to all output_nodes that aren't also + // internal_outputs. + std::unordered_set stop_backprop_nodes; + for (int output_node : output_nodes) { + if (internal_outputs.find(output_node) == internal_outputs.end()) { + stop_backprop_nodes.insert(output_node); + } + } + return stop_backprop_nodes; +} + Status SymbolicGradientBuilder::Initialize() { if (outputs_.size() != grad_inputs_.size()) { return errors::InvalidArgument( @@ -202,11 +262,16 @@ Status SymbolicGradientBuilder::Initialize() { } grad_outputs_->clear(); grad_outputs_->resize(inputs_.size()); - // Populate `output_nodes_` from node ids in `outputs_`. - output_nodes_.reserve(outputs_.size()); + + std::unordered_set output_nodes; + output_nodes.reserve(outputs_.size()); for (size_t i = 0; i < outputs_.size(); ++i) { - output_nodes_.insert(outputs_[i].node()->id()); + output_nodes.insert(outputs_[i].node()->id()); } + + std::unordered_set stop_backprop_nodes = + GetStopBackpropNodes(reachable_nodes, output_nodes); + // Populate `input_nodes_` from Outputs in `inputs_`. input_nodes_.reserve(inputs_.size()); for (size_t i = 0; i < inputs_.size(); ++i) { @@ -237,7 +302,7 @@ Status SymbolicGradientBuilder::Initialize() { backprops_[{n, i}].clear(); } int num_expected_backprops = 0; - if (output_nodes_.find(n->id()) == output_nodes_.end()) { + if (stop_backprop_nodes.find(n->id()) == stop_backprop_nodes.end()) { // Internal node: continue BFS along connected outputs. for (const Edge* e : n->out_edges()) { // If a node is not reachable from outputs_, @@ -250,9 +315,10 @@ Status SymbolicGradientBuilder::Initialize() { } ++num_expected_backprops; } - } else { - // Output node: stop BFS and update `num_expected_backprops` for - // each Output in `outputs_` that references `n`. + } + if (output_nodes.find(n->id()) != output_nodes.end()) { + // Output node: update `num_expected_backprops` for each Output in + // `outputs_` that references `n`. for (const Output& output : outputs_) { if (output.node() == n) { ++num_expected_backprops; @@ -323,7 +389,7 @@ Status SymbolicGradientBuilder::CallGradFunction( Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, const Output& summed_grads) { - // TOOD(skyewm): detect second-order gradient and return bad status + // TODO(skyewm): detect second-order gradient and return bad status // TODO(skyewm): handle (or at least detect) nested while loops // TODO(skyewm): handle NoGradient in while loop diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index dcaf10c340c61142c6f436f74285ea29a83630a9..07a062e704ed6ffc6389b5897309957a1bfcd1c2 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -453,6 +454,45 @@ TEST_F(GradientsTest, UnreachableInput) { " for node 'z' as it's unreachable from the output node(s)."); } +TEST_F(GradientsTest, DependentOutputs) { + auto x = Placeholder(scope_test_, DT_FLOAT); + auto y0 = Square(scope_test_, x); + auto y1 = Square(scope_test_, y0); + auto y2 = Square(scope_test_, y1); + // Requesting the gradients for y0 and y2 should return the sum of their + // individual gradients. + std::vector grad_outputs; + TF_EXPECT_OK(AddSymbolicGradients(scope_test_, {y0, y2}, {x}, &grad_outputs)); + ClientSession session(scope_test_); + std::vector grad_result; + TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 1); + EXPECT_EQ(grad_result[0].flat()(0), 17502.0f); +} + +TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) { + auto x = Placeholder(scope_test_, DT_FLOAT); + auto y0 = Square(scope_test_, x); + // y1, y2, and y3 all use y0. This means the backwards pass will need to wait + // for the gradient for all three. + auto y1 = Square(scope_test_, y0); + auto y2 = Square(scope_test_, y0); + auto y3 = Square(scope_test_, y2); + std::vector grad_outputs; + // By requesting y0, y1, and y3 we test that the computation correctly waits + // for all the points in backprop where gradients need to be summed from + // multiple branches. + TF_EXPECT_OK( + AddSymbolicGradients(scope_test_, {y0, y1, y3}, {x}, &grad_outputs)); + ClientSession session(scope_test_); + std::vector grad_result; + TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 1); + EXPECT_EQ(grad_result[0].flat()(0), 17610.0f); +} + // StopGradientSingleOutputMultiEdgeTest tests combinations of valid and // 'NoGradient' (induced by StopGradient op) returned along multiple edges from // a single nodes output. diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index ac288b1d834d267f5bab887f45de8173e31f88ea..d7446b9560fd7dc8377ea3710641906b274313a9 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define _USE_MATH_DEFINES +#include + #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/math_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -200,8 +203,8 @@ Status TanhGrad(const Scope& scope, const Operation& op, // evaluated. Scope grad_scope = scope.WithControlDependencies(grad); auto y = ConjugateHelper(grad_scope, op.output(0)); - grad_outputs->push_back(internal::TanhGrad(scope, y, grad)); - return scope.status(); + grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad)); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Tanh", TanhGrad); @@ -256,8 +259,8 @@ Status SigmoidGrad(const Scope& scope, const Operation& op, // evaluated. Scope grad_scope = scope.WithControlDependencies(grad); auto y = ConjugateHelper(grad_scope, op.output(0)); - grad_outputs->push_back(internal::SigmoidGrad(scope, y, grad)); - return scope.status(); + grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad)); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad); @@ -484,7 +487,7 @@ Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op, auto grad = grad_inputs[0]; auto zeros = ZerosLike(scope, grad); auto gx_1 = Where3(scope, comparator, grad, zeros); - auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros); + auto gx_2 = Where3(scope, comparator, zeros, grad); return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); } @@ -696,15 +699,32 @@ Status MeanGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Mean", MeanGrad); +Status ErfGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto grad = grad_inputs[0]; + auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), + grad.type()); + Scope grad_scope = scope.WithControlDependencies(grad); + auto x = ConjugateHelper(grad_scope, op.input(0)); + // grad * 2/sqrt(pi) * exp(-x**2) + auto dx = Mul(grad_scope, + Mul(grad_scope, grad, two_over_root_pi), + Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x)))); + grad_outputs->push_back(dx); + return grad_scope.status(); +} +REGISTER_GRADIENT_OP("Erf", ErfGrad); + Status LgammaGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { auto grad = grad_inputs[0]; Scope grad_scope = scope.WithControlDependencies(grad); auto x = ConjugateHelper(grad_scope, op.input(0)); - auto dx = Mul(scope, grad, Digamma(scope, x)); + auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x)); grad_outputs->push_back(dx); - return scope.status(); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Lgamma", LgammaGrad); diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index a174f223ad59b0a111b3d13cb59fb2b13a0095b0..6313f41da5e5f9cf88be4c8a84408a8df77f0e25 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -64,7 +64,9 @@ class CWiseUnaryGradTest : public ::testing::Test { IMAG, CONJ, COMPLEX, - ANGLE + ANGLE, + LGAMMA, + ERF }; template @@ -168,6 +170,12 @@ class CWiseUnaryGradTest : public ::testing::Test { case ANGLE: y = Angle(scope_, x); break; + case LGAMMA: + y = Lgamma(scope_, x); + break; + case ERF: + y = Erf(scope_, x); + break; } float max_error; @@ -503,6 +511,42 @@ TEST_F(CWiseUnaryGradTest, Angle) { TestCWiseGrad(ANGLE, x_fn); } +TEST_F(CWiseUnaryGradTest, Lgamma) { + auto x_fn = [this](const int i) { + return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5}); + }; + TestCWiseGrad(LGAMMA, x_fn); +} + +TEST_F(CWiseUnaryGradTest, Lgamma_Complex) { + auto x_fn = [this](const int i) { + return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}}); + }; + // TODO(kbsriram) + // Add test when the lgamma kernel supports complex numbers + if (false) { + TestCWiseGrad(LGAMMA, x_fn); + } +} + +TEST_F(CWiseUnaryGradTest, Erf) { + auto x_fn = [this](const int i) { + return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3}); + }; + TestCWiseGrad(ERF, x_fn); +} + +TEST_F(CWiseUnaryGradTest, Erf_Complex) { + auto x_fn = [this](const int i) { + return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}}); + }; + // TODO(kbsriram) + // Add test when the erf kernel supports complex numbers + if (false) { + TestCWiseGrad(ERF, x_fn); + } +} + class MathGradTest : public ::testing::Test { protected: MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} @@ -821,17 +865,5 @@ TEST_F(NaryGradTest, Minimum) { RunTest(x, x_init_value, y, shape); } -TEST_F(NaryGradTest, Lgamma) { - TensorShape shape({3, 2}); - auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto y = Lgamma(scope_, x); - // Select values to avoid instability when computing finite differences. - // Ref: https://en.wikipedia.org/wiki/File:Gamma_plot.svg - Tensor x_init_value = - test::AsTensor({-3.5f, -2.5f, -1.5f, 1.0f, 2.0f, 3.5f}, {3, 2}); - RunTest(x, x_init_value, y, shape); - // TODO(suharshs): add test case for complex values -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc index 0030c2b2a7b69afe2151e88ef5d6d0755f72bfa7..a04f37067dd95fc452b8e565c6bc73128c0423b2 100644 --- a/tensorflow/cc/ops/const_op.cc +++ b/tensorflow/cc/ops/const_op.cc @@ -19,19 +19,17 @@ limitations under the License. namespace tensorflow { namespace ops { -Output Const(const Scope& scope, const Input::Initializer& val) { +namespace { +template +Output ConstHelper(const Scope& scope, const T& value, DataType dtype) { if (!scope.ok()) return Output(); - if (!val.status.ok()) { - scope.UpdateStatus(val.status); - return Output(); - } Node* ret; Graph* graph = scope.graph(); const string unique_name = scope.GetUniqueNameForOp("Const"); auto builder = NodeBuilder(unique_name, "Const") - .Attr("value", val.tensor) - .Attr("dtype", val.tensor.dtype()); + .Attr("value", value) + .Attr("dtype", dtype); scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(graph, &ret)); if (!scope.ok()) return Output(); @@ -41,6 +39,19 @@ Output Const(const Scope& scope, const Input::Initializer& val) { return Output(ret); } +} // namespace + +Output Const(const Scope& scope, const Input::Initializer& val) { + if (!val.status.ok()) { + scope.UpdateStatus(val.status); + return Output(); + } + return ConstHelper(scope, val.tensor, val.tensor.dtype()); +} + +Output ConstFromProto(const Scope& scope, const TensorProto& proto) { + return ConstHelper(scope, proto, proto.dtype()); +} NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) { if (!inp.status().ok()) { diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index 516800920f282be0590ef72b26a7fdd8b92a38f9..d11fda475b3db58bf83cdb94079c8fde8d1170f7 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -28,6 +28,8 @@ namespace ops { Output Const(const Scope& scope, const Input::Initializer& val); +Output ConstFromProto(const Scope& scope, const TensorProto& proto); + NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); template diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 3184edeb3307cafcbfbc41c6477fd092ab613b46..69b5d7fd47cae9b54d3e0ae42b0d3936e3c7c696 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -100,6 +100,20 @@ TEST(ConstOpTest, WithExplicitShape) { ExpectNodeEqual(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3}); } +TEST(ConstOpTest, FromProto) { + Scope root = Scope::NewRootScope(); + TensorProto proto; + proto.set_dtype(DT_DOUBLE); + TensorShape({2, 2}).AsProto(proto.mutable_tensor_shape()); + for (int i = 0; i < 4; ++i) { + proto.add_double_val(static_cast(i)); + } + auto c = ops::ConstFromProto(root, proto); + TF_CHECK_OK(root.status()); + EXPECT_EQ(c.op().output_type(0), DT_DOUBLE); + ExpectNodeEqual(c.node(), {0.0, 1.0, 2.0, 3.0}, {2, 2}); +} + TEST(ConstOpTest, InvalidInitializer) { Scope root = Scope::NewRootScope(); ops::Const(root, {{2.0}, {"df"}}); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 1cc7cf3f2021ede8269368aa46007b5ceaace606..67b2e4b81a985731ad5e41ce68a5aeaa9fcef6b9 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -56,6 +56,7 @@ cc_library( ":constants", ] + if_not_mobile([ "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -88,6 +89,7 @@ tf_cc_test( filegroup( name = "saved_model_half_plus_two", srcs = glob([ + "testdata/half_plus_two_forward_compatibility/**", "testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_main_op/**", "testdata/half_plus_two/**", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index f98abc8a817eca7bc129bb03a2ad31b97d957065..462308a48f1e64d368b2a29cde8b6180b2552f2f 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -224,6 +225,18 @@ Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, return Status::OK(); } +// For forward compatibility, remove new default attributes from the graph def +// that were not present in the consumer (e.g. If graph was exported using +// code that's newer than the server and a new default attr was added). +Status RemoveNewDefaultAttrsFromMetaGraphDef(MetaGraphDef* meta_graph_def) { + OpListOpRegistry producer_op_registry( + &meta_graph_def->meta_info_def().stripped_op_list()); + OpRegistry* consumer_op_registry = OpRegistry::Global(); + return RemoveNewDefaultAttrsFromGraphDef(meta_graph_def->mutable_graph_def(), + *consumer_op_registry, + producer_op_registry, nullptr); +} + Status LoadSavedModelInternal(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, @@ -241,6 +254,9 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, TF_RETURN_IF_ERROR( FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def)); + TF_RETURN_IF_ERROR( + RemoveNewDefaultAttrsFromMetaGraphDef(&bundle->meta_graph_def)); + TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( bundle->meta_graph_def, session_options, &bundle->session)); diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 0ad6b33bba5fcceaca68e2f179cef2232c689a80..6dd14837b5e9f31395a26b49082e3339817473d0 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -29,10 +29,12 @@ limitations under the License. namespace tensorflow { namespace { -constexpr char kTestDataPbTxt[] = - "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; +constexpr char kTestDataForwardCompatibility[] = + "cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123"; constexpr char kTestDataMainOp[] = "cc/saved_model/testdata/half_plus_two_main_op/00000123"; +constexpr char kTestDataPbTxt[] = + "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; constexpr char kTestDataSharded[] = "cc/saved_model/testdata/half_plus_two/00000123"; @@ -167,6 +169,24 @@ TEST_F(LoaderTest, PbtxtFormat) { CheckSavedModelBundle(export_dir, bundle); } +// Forward compatibility graph has a new attr with a default value equal to the +// value used by the server. If we handle new default attrs correctly, this test +// will pass. This simulates adding new atts to the training code while server +// code lags behind. +TEST_F(LoaderTest, ForwardCompatibility) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + // TODO(b/67753689): Add support for regenerating this model in the export + // code. + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataForwardCompatibility); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + TEST_F(LoaderTest, MainOpFormat) { SavedModelBundle bundle; SessionOptions session_options; diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/assets/foo.txt new file mode 100644 index 0000000000000000000000000000000000000000..f9ff036688007836524129e23f5cf82edd1e8910 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/saved_model.pbtxt b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/saved_model.pbtxt new file mode 100755 index 0000000000000000000000000000000000000000..e799b3579c6e79de83989d4f19662becae4a5301 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/saved_model.pbtxt @@ -0,0 +1,2728 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "ParseExample" + input_arg { + name: "serialized" + type_attr: "TInputs" + } + input_arg { + name: "names" + type: DT_STRING + } + input_arg { + name: "sparse_keys" + type: DT_STRING + number_attr: "Nsparse" + } + input_arg { + name: "dense_keys" + type: DT_STRING + number_attr: "Ndense" + } + input_arg { + name: "dense_defaults" + type_list_attr: "Tdense" + } + output_arg { + name: "sparse_indices" + type: DT_INT64 + number_attr: "Nsparse" + } + output_arg { + name: "sparse_values" + type_list_attr: "sparse_types" + } + output_arg { + name: "sparse_shapes" + type: DT_INT64 + number_attr: "Nsparse" + } + output_arg { + name: "dense_values" + type_list_attr: "Tdense" + } + attr { + name: "Nsparse" + type: "int" + has_minimum: true + } + attr { + name: "TInputs" + type: "type" + default_value { + type: DT_STRING + } + allowed_values { + list { + type: DT_STRING + type: DT_INT64 + } + } + } + attr { + name: "Ndense" + type: "int" + has_minimum: true + } + attr { + name: "sparse_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_FLOAT + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "Tdense" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_FLOAT + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "dense_shapes" + type: "list(shape)" + has_minimum: true + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + } + tags: "serve" + tensorflow_version: "1.1.0-rc2" + tensorflow_git_version: "unknown" + } + graph_def { + node { + name: "a/initial_value" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "a" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "a/Assign" + op: "Assign" + input: "a" + input: "a/initial_value" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@a" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "a/read" + op: "Identity" + input: "a" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@a" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "b/initial_value" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 2.0 + } + } + } + } + node { + name: "b" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "b/Assign" + op: "Assign" + input: "b" + input: "b/initial_value" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@b" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "b/read" + op: "Identity" + input: "b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@b" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "c/initial_value" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3.0 + } + } + } + } + node { + name: "c" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "c/Assign" + op: "Assign" + input: "c" + input: "c/initial_value" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@c" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "c/read" + op: "Identity" + input: "c" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@c" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "tf_example" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } + } + node { + name: "ParseExample/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "ParseExample/key_x2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "ParseExample/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "ParseExample/Reshape" + op: "Reshape" + input: "ParseExample/key_x2" + input: "ParseExample/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "ParseExample/ParseExample/names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "ParseExample/ParseExample/dense_keys_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "x" + } + } + } + } + node { + name: "ParseExample/ParseExample/dense_keys_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "x2" + } + } + } + } + node { + name: "ParseExample/ParseExample" + op: "ParseExample" + input: "tf_example" + input: "ParseExample/ParseExample/names" + input: "ParseExample/ParseExample/dense_keys_0" + input: "ParseExample/ParseExample/dense_keys_1" + input: "ParseExample/Const" + input: "ParseExample/Reshape" + attr { + key: "Ndense" + value { + i: 2 + } + } + attr { + key: "TInputs" + value { + type: DT_STRING + } + } + attr { + key: "Nsparse" + value { + i: 0 + } + } + attr { + key: "Tdense" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "dense_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "sparse_types" + value { + list { + } + } + } + } + node { + name: "x" + op: "Identity" + input: "ParseExample/ParseExample" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Mul" + op: "Mul" + input: "a/read" + input: "x" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "y" + op: "Add" + input: "Mul" + input: "b/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Mul_1" + op: "Mul" + input: "a/read" + input: "x" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "y2" + op: "Add" + input: "Mul_1" + input: "c/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "x2" + op: "Identity" + input: "ParseExample/ParseExample:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Mul_2" + op: "Mul" + input: "a/read" + input: "x2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "y3" + op: "Add" + input: "Mul_2" + input: "c/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "/tmp/original/export/assets/foo.txt" + } + } + } + } + node { + name: "filename_tensor/initial_value" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "foo.txt" + } + } + } + } + node { + name: "filename_tensor" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "filename_tensor/Assign" + op: "Assign" + input: "filename_tensor" + input: "filename_tensor/initial_value" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@filename_tensor" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "filename_tensor/read" + op: "Identity" + input: "filename_tensor" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@filename_tensor" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Assign/value" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "foo.txt" + } + } + } + } + node { + name: "Assign" + op: "Assign" + input: "filename_tensor" + input: "Assign/value" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@filename_tensor" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "init" + op: "NoOp" + input: "^a/Assign" + input: "^b/Assign" + input: "^c/Assign" + } + node { + name: "group_deps" + op: "NoOp" + input: "^Assign" + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_80e928f1e0c844239d136d1ea966099d/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 3 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 3 + } + } + string_val: "a" + string_val: "b" + string_val: "c" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 3 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 3 + } + } + string_val: "" + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "a" + input: "b" + input: "c" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "a" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "a" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@a" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "b" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "b" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@b" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "c" + } + } + } + } + node { + name: "save/RestoreV2_2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_2/tensor_names" + input: "save/RestoreV2_2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_2" + op: "Assign" + input: "c" + input: "save/RestoreV2_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@c" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + input: "^save/Assign_2" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 23 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "asset_filepaths" + value { + node_list { + value: "Const:0" + } + } + } + collection_def { + key: "legacy_init_op" + value { + node_list { + value: "group_deps" + } + } + } + collection_def { + key: "saved_model_assets" + value { + any_list { + value { + type_url: "type.googleapis.com/tensorflow.AssetFileDef" + value: "\n\t\n\007Const:0\022\007foo.txt" + } + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\003a:0\022\010a/Assign\032\010a/read:0" + value: "\n\003b:0\022\010b/Assign\032\010b/read:0" + value: "\n\003c:0\022\010c/Assign\032\010c/read:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\003a:0\022\010a/Assign\032\010a/read:0" + value: "\n\003b:0\022\010b/Assign\032\010b/read:0" + value: "\n\003c:0\022\010c/Assign\032\010c/read:0" + } + } + } + signature_def { + key: "classify_x2_to_y3" + value { + inputs { + key: "inputs" + value { + name: "x2:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + outputs { + key: "scores" + value { + name: "y3:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/classify" + } + } + signature_def { + key: "classify_x_to_y" + value { + inputs { + key: "inputs" + value { + name: "tf_example:0" + dtype: DT_STRING + tensor_shape { + unknown_rank: true + } + } + } + outputs { + key: "scores" + value { + name: "y:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/classify" + } + } + signature_def { + key: "regress_x2_to_y3" + value { + inputs { + key: "inputs" + value { + name: "x2:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + outputs { + key: "outputs" + value { + name: "y3:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/regress" + } + } + signature_def { + key: "regress_x_to_y" + value { + inputs { + key: "inputs" + value { + name: "tf_example:0" + dtype: DT_STRING + tensor_shape { + unknown_rank: true + } + } + } + outputs { + key: "outputs" + value { + name: "y:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/regress" + } + } + signature_def { + key: "regress_x_to_y2" + value { + inputs { + key: "inputs" + value { + name: "tf_example:0" + dtype: DT_STRING + tensor_shape { + unknown_rank: true + } + } + } + outputs { + key: "outputs" + value { + name: "y2:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/regress" + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "x:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + outputs { + key: "y" + value { + name: "y:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/variables/variables.data-00000-of-00001 new file mode 100755 index 0000000000000000000000000000000000000000..15b75d6ef6bffc336d138d923badb3928b8c4c13 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/variables/variables.index new file mode 100755 index 0000000000000000000000000000000000000000..7ec9fb4fe2dd21d0a6c324aecd7658fc37cf2326 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_forward_compatibility/00000123/variables/variables.index differ diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index fc5c6ce58d95b4ad06e48d40369740102efe0a66..ae22f7edc423247b34895411d19d7a3c21f86d4f 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,10 +164,6 @@ string RewriteWithName(const string& name, string code, // Generate methods for args (inputs). Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { - *methods += R"( - void** args() { return args_; } - const void *const *args() const { return args_; } -)"; size_t num_args = ps.parameters_size(); if (compile_result.has_context_arg) { // If the compiled function needs a XlaLocalRuntimeContext* arg, it's @@ -184,21 +180,21 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); const string code = R"( void set_arg{{NAME}}_data(void* data) { - args_[{{I}}] = data; + set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { - return static_cast<{{TYPE}}*>(args_[{{I}}]); + return static_cast<{{TYPE}}*>(arg_data({{I}})); } {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } const {{TYPE}}* arg{{NAME}}_data() const { - return static_cast(args_[{{I}}]); + return static_cast(arg_data({{I}})); } const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const { return (*static_cast( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -213,74 +209,33 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, Status GenResultMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { - // Non-tuple (i.e. single-result) case. - if (config.fetch_size() != 1) { - return errors::InvalidArgument( - "non-tuple result implies 1 fetch, but got ", config.fetch_size(), - " fetches"); - } - *methods += R"( - void** results() { return temps_ + kResultIndex; } - const void *const *results() const { return temps_ + kResultIndex; } -)"; - std::vector> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites)); - const string code = R"( - {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>(temps_[kResultIndex]); - } - {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { - return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - temps_[kResultIndex])){{INDICES}}; - } - const {{TYPE}}* result{{NAME}}_data() const { - return static_cast(temps_[kResultIndex]); - } - const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { - return (*static_cast( - temps_[kResultIndex])){{INDICES}}; + // The XlaCompiler we use to build the xla computation always generates a + // tuple result, and we rely on this to simplify code generation. + return errors::Internal("codegen requires the XLA result to be a tuple"); } -)"; - *methods += RewriteWithName("0", code, rewrites); - if (!config.fetch(0).name().empty()) { - *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites); - } - return Status::OK(); - } - // Tuple (i.e. multi-result) case. if (config.fetch_size() != ps.result().tuple_shapes_size()) { return errors::InvalidArgument("mismatch between fetch_size(", config.feed_size(), ") and tuple_size(", ps.result().tuple_shapes_size(), ")"); } - *methods += R"( - void** results() { - return static_cast(temps_[kResultIndex]); - } - const void *const *results() const { - return static_cast(temps_[kResultIndex]); - } -)"; for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR( AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); string code = R"( {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>( - static_cast(temps_[kResultIndex])[{{I}}]); + return static_cast<{{TYPE}}*>(result_data({{I}})); } {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - static_cast(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } const {{TYPE}}* result{{NAME}}_data() const { - return static_cast<{{TYPE}}*>( - static_cast(temps_[kResultIndex])[{{I}}]); + return static_cast(result_data({{I}})); } const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { return (*static_cast( - static_cast(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -291,6 +246,84 @@ Status GenResultMethods(const tf2xla::Config& config, return Status::OK(); } +// Generates code implementing {Arg,Result}Names(), where T is one of +// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string +// literal in the array, with nullptr terminating the array. +template +string GenNameToIndexCode(const T& entries, bool generate) { + // No need for a static array if we're not supposed to generate the data. + if (!generate) { + return "{\n return nullptr;\n }"; + } + // Determine when to stop. We stop emitting string literals after the last + // non-empty name. + int end = entries.size(); + for (int i = entries.size() - 1; i >= 0; --i) { + if (!entries[i].name().empty()) { + break; + } + end = i; + } + // Emit string literals up to the last non-empty name. + string code = "{\n static const char* kNames[] = {"; + for (int i = 0; i < end; ++i) { + if (i > 0) { + code += ", "; + } + code += "\""; + code += entries[i].name(); + code += "\""; + } + if (end > 0) { + code += ", "; + } + code += "nullptr};\n return kNames;\n }"; + return code; +} + +// Converts the given `str` into a comma-separated list of per-character values. +string StringToCharList(const string& str) { + string list; + for (const char c : str) { + if (!list.empty()) { + list += ","; + } + list += strings::StrCat(static_cast(c)); + } + return list; +} + +string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) { + // No need for any static magic if we're not supposed to generate the data. + if (!generate) { + return "{\n return nullptr;\n }"; + } + // The parameter names are currently meaningless, and redundant with the rest + // of our metadata, so clear them out to avoid confusion and save space. + program_shape.clear_parameter_names(); + const string proto_str = program_shape.SerializeAsString(); + // Embed the program shape as a serialized protobuf in the header file. + // + // TODO(toddw): This strategy will likely fail for larger protobufs, depending + // on the C++ compiler that is used. Figure out another solution if necessary. + string code = R"({ + static const xla::ProgramShape* kShape = []() { + static const char kProto[] = {{{PROTO_LIST}}}; + static constexpr int kProtoSize = {{PROTO_SIZE}}; + xla::ProgramShape* shape = new xla::ProgramShape; + shape->ParseFromArray(kProto, kProtoSize); + return shape; + }(); + return kShape; + })"; + str_util::ReplaceAllPairs( + &code, { + {"{{PROTO_LIST}}", StringToCharList(proto_str)}, + {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())}, + }); + return code; +} + Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { for (const tf2xla::Feed& feed : config.feed()) { if (!feed.name().empty()) { @@ -336,24 +369,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, const size_t temp_bytes_total = total_buffer_bytes(itemp.data(), itemp.size()); - // Create rewrite strings for the optional context arg. - string context_include; - string context_set_arg, context_set_thread_pool, context_member_var; - string run_result = "true"; - string error_msg = "tensorflow::string()"; - if (compile_result.has_context_arg) { - // NOTE: Extra spaces and newlines are used to ensure nice formatting. - context_include = - "#include " - "\"tensorflow/compiler/tf2xla/" - "xla_local_runtime_context.h\"\n"; - context_set_arg = " args_[kNumArgs-1] = &context_;\n"; - context_set_thread_pool = " context_.thread_pool = pool;\n"; - context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n"; - run_result = "!context_.error"; - error_msg = "context_.error_msg"; - } - // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { @@ -366,6 +381,19 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, ns_end += strings::StrCat("} // end namespace ", n, "\n"); } + // Generate metadata. + const string arg_names_code = + GenNameToIndexCode(config.feed(), opts.gen_name_to_index); + const string result_names_code = + GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); + const string include_xla_data_proto = + opts.gen_program_shape + ? + R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" + : ""; + const string program_shape_code = + GenProgramShapeCode(ps, opts.gen_program_shape); + // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. *header = @@ -380,22 +408,23 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, #ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) #define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) -{{CONTEXT_INCLUDE}} -#include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/macros.h" +{{INCLUDE_XLA_DATA_PROTO}} +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" namespace Eigen { struct ThreadPoolDevice; } +namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( - void* result, xla::ExecutableRunOptions* run_options, - void** args, void** temps); + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps); {{NS_START}} // {{CLASS}} represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. Usage example: +// TensorFlow graph, now compiled into executable code. This extends the generic +// XlaCompiledCpuFunction class with statically type-safe arg and result +// methods. Usage example: // // {{CLASS}} computation; // // ...set args using computation.argN methods @@ -411,9 +440,9 @@ extern "C" void {{ENTRY}}( // buffer allocation strategy. // // Under the default allocation strategy, this class is thread-compatible: -// o Calls to non-const methods require exclusive access to the object. -// o Concurrent calls to const methods are OK, if those calls are made while -// it is guaranteed that no thread may call a non-const method. +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. // // The logical function signature is: // {{PROGRAM_SHAPE}} @@ -423,7 +452,7 @@ extern "C" void {{ENTRY}}( // arg bytes aligned: {{ARG_BYTES_ALIGNED}} // temp bytes total: {{TEMP_BYTES_TOTAL}} // temp bytes aligned: {{TEMP_BYTES_ALIGNED}} -class {{CLASS}} { +class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; @@ -434,47 +463,31 @@ class {{CLASS}} { return kArgSizes; } - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, - - // Only allocate result and temp buffers. - // Use set_argN_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, - }; - - {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { - if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); - } -{{CONTEXT_SET_ARG}} - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); - } - - ~{{CLASS}}() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - } - - // Sets the thread pool to use during the Run call. - {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { - run_options_.set_intra_op_thread_pool(pool); -{{CONTEXT_SET_THREAD_POOL}} - return *this; - } - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run() { - {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_); - return {{RUN_RESULT}}; - } - - // Returns the error message from the previous failed Run call. - tensorflow::string error_msg() const { return {{ERROR_MSG}}; } + // Returns static data used to create an XlaCompiledCpuFunction. + static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() { + static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ + XlaCompiledCpuFunction::StaticData* data = + new XlaCompiledCpuFunction::StaticData; + data->raw_function = {{ENTRY}}; + data->arg_sizes = ArgSizes(); + data->num_args = kNumArgs; + data->temp_sizes = TempSizes(); + data->num_temps = kNumTemps; + data->result_index = kResultIndex; + data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; + data->arg_names = StaticArgNames(); + data->result_names = StaticResultNames(); + data->program_shape = StaticProgramShape(); + return data; + }(); + return *kStaticData; + } + + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} + + {{CLASS}}(const {{CLASS}}&) = delete; + {{CLASS}}& operator=(const {{CLASS}}&) = delete; // Arg methods for managing input buffers. Buffers are in row-major order. // There is a set of methods for each positional argument, with the following @@ -493,10 +506,6 @@ class {{CLASS}} { // Returns a reference to the value of type T for positional argument N, // with dim indices specifying which value. No bounds checking is performed // on dim indices. - // - // void** args() - // Returns an array of argument buffers, where args()[N] is the buffer for - // positional argument N. {{METHODS_ARG}} // Result methods for managing output buffers. Buffers are in row-major order. @@ -511,10 +520,6 @@ class {{CLASS}} { // with dim indices specifying which value. No bounds checking is performed // on dim indices. // - // void** results() - // Returns an array of result buffers, where results()[N] is the buffer for - // positional result N. - // // Unlike the arg methods, there is no set_resultN_data method. The result // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} @@ -522,7 +527,7 @@ class {{CLASS}} { private: // Number of result and temporary buffers for the compiled computation. static constexpr size_t kNumTemps = {{TEMP_NUM}}; - // The 0-based index of the result in the temporary buffers. + // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = {{RESULT_INDEX}}; // Byte size of each result / temporary buffer. There are kNumTemps entries. @@ -531,14 +536,14 @@ class {{CLASS}} { return kTempSizes; } - void* args_[kNumArgs]; - void* temps_[kNumTemps]; - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; - xla::ExecutableRunOptions run_options_; -{{CONTEXT_MEMBER_VAR}} + // Array of names of each positional argument, terminated by nullptr. + static const char** StaticArgNames() {{ARG_NAMES_CODE}} + + // Array of names of each positional result, terminated by nullptr. + static const char** StaticResultNames() {{RESULT_NAMES_CODE}} - TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}}); + // Shape of the args and results. + static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}} }; {{NS_END}} @@ -550,22 +555,22 @@ class {{CLASS}} { const std::vector> rewrites = { {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, - {"{{CONTEXT_INCLUDE}}\n", context_include}, - {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var}, - {"{{CONTEXT_SET_ARG}}\n", context_set_arg}, - {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{ERROR_MSG}}", error_msg}, + {"{{HAS_CONTEXT_ARG}}", + compile_result.has_context_arg ? "true" : "false"}, + {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE_CODE}}", program_shape_code}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, - {"{{RUN_RESULT}}", run_result}, + {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 740edd1e83410ad0d3b854adbec20fb1cab88440..76dd0cc3cf9470a1beb2a4725724f640aecfec7f 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -34,6 +34,12 @@ struct HeaderOpts { // Namespaces specifies a list of C++ namespaces to add to the generated // header. If empty, all symbols will be in the global namespace. std::vector namespaces; + + // If true, generate name-to-index data for Lookup{Arg,Result}Index methods. + bool gen_name_to_index = false; + + // If true, generate program shape data for the ProgramShape method. + bool gen_program_shape = false; }; // GenerateHeader uses the meta-information from compile_result to generate a diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 98cbd67e53432e7c131c2daa27e86e3a613161a1..0f6114666fcc89c631434527d2ae8c92c039ffea 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -127,6 +127,8 @@ TEST(GenerateHeader, Golden) { HeaderOpts opts; opts.class_name = "MyClass"; opts.namespaces = {"foo", "bar"}; + opts.gen_name_to_index = true; + opts.gen_program_shape = true; tf2xla::Config config; tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("feed0"); @@ -145,7 +147,8 @@ TEST(GenerateHeader, Golden) { xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), xla::ShapeUtil::MakeOpaqueShape(), }, - xla::ShapeUtil::MakeShape(xla::U32, {5, 6})); + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 01963c6df4682ec8c23a93201d7fbbab63558060..65f342ce27ef09092f252f791973f245a8cdd6f3 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -9,24 +9,25 @@ #ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) #define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" namespace Eigen { struct ThreadPoolDevice; } +namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( - void* result, xla::ExecutableRunOptions* run_options, - void** args, void** temps); + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps); namespace foo { namespace bar { // MyClass represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. Usage example: +// TensorFlow graph, now compiled into executable code. This extends the generic +// XlaCompiledCpuFunction class with statically type-safe arg and result +// methods. Usage example: // // MyClass computation; // // ...set args using computation.argN methods @@ -42,19 +43,19 @@ namespace bar { // buffer allocation strategy. // // Under the default allocation strategy, this class is thread-compatible: -// o Calls to non-const methods require exclusive access to the object. -// o Concurrent calls to const methods are OK, if those calls are made while -// it is guaranteed that no thread may call a non-const method. +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6] +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 // arg bytes aligned: 128 // temp bytes total: 126 // temp bytes aligned: 224 -class MyClass { +class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = 3; @@ -65,47 +66,31 @@ class MyClass { return kArgSizes; } - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, - - // Only allocate result and temp buffers. - // Use set_argN_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, - }; - - MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { - if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); - } - args_[kNumArgs-1] = &context_; - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); - } - - ~MyClass() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - } - - // Sets the thread pool to use during the Run call. - MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { - run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; - return *this; - } - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run() { - entry_point(temps_[kResultIndex], &run_options_, args_, temps_); - return !context_.error; - } - - // Returns the error message from the previous failed Run call. - tensorflow::string error_msg() const { return context_.error_msg; } + // Returns static data used to create an XlaCompiledCpuFunction. + static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() { + static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ + XlaCompiledCpuFunction::StaticData* data = + new XlaCompiledCpuFunction::StaticData; + data->raw_function = entry_point; + data->arg_sizes = ArgSizes(); + data->num_args = kNumArgs; + data->temp_sizes = TempSizes(); + data->num_temps = kNumTemps; + data->result_index = kResultIndex; + data->requires_runtime_context = true; + data->arg_names = StaticArgNames(); + data->result_names = StaticResultNames(); + data->program_shape = StaticProgramShape(); + return data; + }(); + return *kStaticData; + } + + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} + + MyClass(const MyClass&) = delete; + MyClass& operator=(const MyClass&) = delete; // Arg methods for managing input buffers. Buffers are in row-major order. // There is a set of methods for each positional argument, with the following @@ -124,66 +109,59 @@ class MyClass { // Returns a reference to the value of type T for positional argument N, // with dim indices specifying which value. No bounds checking is performed // on dim indices. - // - // void** args() - // Returns an array of argument buffers, where args()[N] is the buffer for - // positional argument N. - - void** args() { return args_; } - const void *const *args() const { return args_; } void set_arg0_data(void* data) { - args_[0] = data; + set_arg_data(0, data); } float* arg0_data() { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } float& arg0(size_t dim0, size_t dim1) { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } const float* arg0_data() const { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } const float& arg0(size_t dim0, size_t dim1) const { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } void set_arg_myfeed_data(void* data) { - args_[0] = data; + set_arg_data(0, data); } float* arg_myfeed_data() { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } float& arg_myfeed(size_t dim0, size_t dim1) { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } const float* arg_myfeed_data() const { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } const float& arg_myfeed(size_t dim0, size_t dim1) const { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } void set_arg1_data(void* data) { - args_[1] = data; + set_arg_data(1, data); } tensorflow::int64* arg1_data() { - return static_cast(args_[1]); + return static_cast(arg_data(1)); } tensorflow::int64& arg1(size_t dim0, size_t dim1) { return (*static_cast( - args_[1]))[dim0][dim1]; + arg_data(1)))[dim0][dim1]; } const tensorflow::int64* arg1_data() const { - return static_cast(args_[1]); + return static_cast(arg_data(1)); } const tensorflow::int64& arg1(size_t dim0, size_t dim1) const { return (*static_cast( - args_[1]))[dim0][dim1]; + arg_data(1)))[dim0][dim1]; } // Result methods for managing output buffers. Buffers are in row-major order. @@ -198,50 +176,43 @@ class MyClass { // with dim indices specifying which value. No bounds checking is performed // on dim indices. // - // void** results() - // Returns an array of result buffers, where results()[N] is the buffer for - // positional result N. - // // Unlike the arg methods, there is no set_resultN_data method. The result // buffers are managed internally, and may change after each call to Run. - void** results() { return temps_ + kResultIndex; } - const void *const *results() const { return temps_ + kResultIndex; } - tensorflow::uint32* result0_data() { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } tensorflow::uint32& result0(size_t dim0, size_t dim1) { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } const tensorflow::uint32* result0_data() const { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } const tensorflow::uint32& result0(size_t dim0, size_t dim1) const { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } tensorflow::uint32* result_myfetch_data() { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } const tensorflow::uint32* result_myfetch_data() const { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } private: // Number of result and temporary buffers for the compiled computation. static constexpr size_t kNumTemps = 6; - // The 0-based index of the result in the temporary buffers. + // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = 5; // Byte size of each result / temporary buffer. There are kNumTemps entries. @@ -250,14 +221,29 @@ class MyClass { return kTempSizes; } - void* args_[kNumArgs]; - void* temps_[kNumTemps]; - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; - xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; + // Array of names of each positional argument, terminated by nullptr. + static const char** StaticArgNames() { + static const char* kNames[] = {"myfeed", nullptr}; + return kNames; + } + + // Array of names of each positional result, terminated by nullptr. + static const char** StaticResultNames() { + static const char* kNames[] = {"myfetch", nullptr}; + return kNames; + } - TF_DISALLOW_COPY_AND_ASSIGN(MyClass); + // Shape of the args and results. + static const xla::ProgramShape* StaticProgramShape() { + static const xla::ProgramShape* kShape = []() { + static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; + static constexpr int kProtoSize = 50; + xla::ProgramShape* shape = new xla::ProgramShape; + shape->ParseFromArray(kProto, kProtoSize); + return shape; + }(); + return kShape; + } }; } // end namespace bar diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index 4e3998b68293aa47f028c745cea36a8c533d237d..5aff10346fa368f214436d1d0837505ffbbc771e 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -64,6 +64,10 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "namespaces are given, within the global namespace."}, {"out_object", &flags->out_object, "Output object file name."}, {"out_header", &flags->out_header, "Output header file name."}, + {"gen_name_to_index", &flags->gen_name_to_index, + "Generate name-to-index data for Lookup{Arg,Result}Index methods."}, + {"gen_program_shape", &flags->gen_program_shape, + "Generate program shape data for the ProgramShape method."}, }; flag_list->insert(flag_list->end(), tmp.begin(), tmp.end()); } diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index e11a0173fa0035237915be80cf66b2bfca0f9b12..3246dbf95c8a60130af91bc3891b15829aa5e638 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -37,6 +37,10 @@ struct MainFlags { string cpp_class; string out_object; string out_header; + + // C++ codegen options + bool gen_name_to_index = false; + bool gen_program_shape = false; }; // Appends to flag_list a tensorflow::Flag for each field in MainFlags. diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index b0b1213a847c586259e3b8f1d175f089c3961dfd..7dfd49cc3b92f83fd64ca62bd2230938ce2d0a65 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -132,6 +132,7 @@ tf_library( cpp_class = "MatMulAndAddComp", graph = "test_graph_tfmatmulandadd.pb", tags = ["manual"], + tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) tf_library( @@ -156,6 +157,8 @@ tf_cc_test( ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", ":test_graph_tfsplits", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 07562e59c8dac942f41af69c289c9f29a9767a6a..6b037f276ad1d6771b904bb970f45f32ae9531b8 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -178,16 +180,6 @@ TEST(TFCompileTest, Gather) { } EXPECT_EQ(gather_const.result0_data(), gather.results()[0]); } - - // Bad indices returns an error. - { - const float params[4] = {1, 2, 3, 4}; - std::copy(params + 0, params + 4, gather.arg0_data()); - const int32 indices[2] = {1, 4}; - std::copy(indices + 0, indices + 2, gather.arg1_data()); - EXPECT_FALSE(gather.Run()); - EXPECT_EQ(gather.error_msg(), "Invalid index for gather"); - } } TEST(TFCompileTest, MatMul2) { @@ -421,6 +413,59 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, LookupNameIndex) { + // add doesn't have any names defined in its config. + AddComp add; + EXPECT_FALSE(add.HasNameIndices()); + + // muladd has names defined for all feeds and fetches. + MatMulAndAddComp muladd; + EXPECT_TRUE(muladd.HasNameIndices()); + + EXPECT_EQ(muladd.LookupArgIndex("x"), 0); + EXPECT_EQ(muladd.LookupArgIndex("y"), 1); + EXPECT_EQ(muladd.LookupArgIndex(""), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_hold"), -1); + EXPECT_EQ(muladd.LookupArgIndex("y_hold"), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_y_prod"), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_y_sum"), -1); + + EXPECT_EQ(muladd.LookupResultIndex("x_y_prod"), 0); + EXPECT_EQ(muladd.LookupResultIndex("x_y_sum"), 1); + EXPECT_EQ(muladd.LookupResultIndex(""), -1); + EXPECT_EQ(muladd.LookupResultIndex("x"), -1); + EXPECT_EQ(muladd.LookupResultIndex("y"), -1); + EXPECT_EQ(muladd.LookupResultIndex("x_hold"), -1); + EXPECT_EQ(muladd.LookupResultIndex("y_hold"), -1); +} + +TEST(TFCompileTest, ProgramShape) { + using xla::ShapeUtil; + const xla::Shape f32_2x2 = ShapeUtil::MakeShape(xla::F32, {2, 2}); + + // add doesn't have the program shape defined. + AddComp add; + ASSERT_TRUE(add.ProgramShape() == nullptr); + + // muladd has the program shape defined. + MatMulAndAddComp muladd; + const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); + ASSERT_TRUE(muladd_shape != nullptr); + ASSERT_EQ(muladd_shape->parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2)); + + const xla::Shape& muladd_result = muladd_shape->result(); + ASSERT_EQ(muladd_result.element_type(), xla::TUPLE); + ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2); + const xla::Shape& muladd_result0 = + ShapeUtil::GetTupleElementShape(muladd_result, 0); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_result0, f32_2x2)); + const xla::Shape& muladd_result1 = + ShapeUtil::GetTupleElementShape(muladd_result, 1); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2)); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 608d461a4cebba92944b8c56fd295394ba6e59b0..4888760acd45f2789193884407b3742a5e9683ec 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -167,6 +167,8 @@ def tf_library(name, graph, config, # The cc_library rule packaging up the header and object file, and needed # kernel implementations. + need_xla_data_proto = (tfcompile_flags and + tfcompile_flags.find("--gen_program_shape") != -1) native.cc_library( name=name, srcs=[object_file], @@ -177,14 +179,13 @@ def tf_library(name, graph, config, # These deps are required by all tf_library targets even if # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", "//tensorflow/core:framework_lite", - ] + (include_standard_runtime_deps and [ + ] + (need_xla_data_proto and [ + # If we're generating the program shape, we must depend on the proto. + "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32", - "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", @@ -292,7 +293,6 @@ def tf_library(name, graph, config, tags=tags, ) - def target_llvm_triple(): """Returns the target LLVM triple to be used for compiling the target.""" # TODO(toddw): Add target_triple for other targets. For details see: diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index cc499c3284689182638665e0884f6377d8d9f3ee..6ab3d474187c7df2131f94c9f42f0d0f2f9d99d7 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -94,6 +94,8 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, StringPiece(obj.data(), obj.size()))); HeaderOpts header_opts; + header_opts.gen_name_to_index = flags.gen_name_to_index; + header_opts.gen_program_shape = flags.gen_program_shape; if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bf63b7e5016c33b47a8839e03894930203ab0654..bf7d9cf14d10f41aa48ea594a8d63db97b9973e1 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -33,6 +33,7 @@ cc_library( deps = [ ":xla_cpu_device", ":xla_cpu_jit", + "//tensorflow/compiler/plugin", ] + if_cuda_is_configured([ ":xla_gpu_device", ":xla_gpu_jit", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index ded01ff6530745dafc22a259576e1f6816d20c02..27c5da08c112664d361b5f969d100eed7b9df65c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -110,7 +110,7 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = - reinterpret_cast(const_cast(t->tensor_data().data())); + reinterpret_cast(const_cast(t->tensor_data().data())); TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index db2ed16f95e71f5055d0cb2fce3cda8a7f2976d9..78d0aa86a8fae9a0c6035bdc579ef800337df917 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -560,6 +560,7 @@ Status MarkForCompilationPass::RunImpl( name = strings::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; } } diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 57b9d6b56bca23e94dc172dce2412ed151643318..2e33fdca657f470270cb25fa2ac661a441b70552 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -39,9 +39,9 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, - DEVICE_CPU_XLA_JIT, options, name_prefix, - &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create( + "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, + /*register_device_for_compilation=*/true, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 888461611fee6dd78d086ca9da67da40d515bca1..7ccea58f6e9aa467402ef78cd6fa89f3feb60e6f 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" @@ -107,18 +108,21 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, std::unique_ptr* device) { + const string& name_prefix, bool register_device_for_compilation, + std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = jit_device_name; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; - registration.compile_resource_ops = true; - XlaOpRegistry::RegisterCompilationDevice(device_name, registration); + if (register_device_for_compilation) { + // These are no-ops if they have already been done previously for + // this device_name/compilation_device_name pair. + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = jit_device_name; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); + } auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { @@ -158,7 +162,8 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, const Metadata** metadata) { - XlaDevice* xla_device = dynamic_cast(ctx->device()); + XlaDevice* xla_device = + dynamic_cast(ctx->device()->UnderlyingDevice()); if (xla_device == nullptr) { return errors::Internal( "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(), diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 0d90b8b692896d8addf5ffead3980a5bf640c85c..d2ec38293c429f04f088bf3726ba97eb4e4b0dba 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -74,6 +74,7 @@ class XlaDevice : public LocalDevice { static Status Create(const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, + bool register_device_for_compilation, std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 4474d8f4eb06afa78ea36332a8cc58f9d240c1b0..5233665ec283a770117aa5bec1c0d01f17a04526 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -39,9 +39,9 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - Status status = - XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, - name_prefix, &device); + Status status = XlaDevice::Create( + "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, + /*register_device_for_compilation=*/true, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4e4cbe200a21b6584f0fefb4cf43874fb213e244..2614deefd8823dcb8f38e9e22ae4e78145d0d96a 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -42,9 +42,9 @@ Status XlaInterpreterDeviceFactory::CreateDevices( (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, - DEVICE_INTERPRETER_XLA_JIT, options, - name_prefix, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create( + "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, + options, name_prefix, /*register_device_for_compilation=*/true, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f0886721546bba3ace76e50608dc4fe61416da5c --- /dev/null +++ b/tensorflow/compiler/plugin/BUILD @@ -0,0 +1,42 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Configuration file for an XLA plugin. + + please don't check in changes to this file. to prevent changes appearing + in git status, use: + + git update-index --assume-unchanged tensorflow/compiler/plugin/BUILD + + To add additional devices to the XLA subsystem, add targets to the + dependency list in the 'plugin' target. For instance: + + deps = ["//tensorflow/compiler/plugin/example:plugin_lib"], + + ** Please don't remove this file - it is supporting some 3rd party plugins ** +""" + +licenses(["notice"]) + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "plugin", + deps = [ + #"//tensorflow/compiler/plugin/example:example_lib", + ], +) diff --git a/tensorflow/compiler/plugin/README.md b/tensorflow/compiler/plugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9dd0d2bdab5e2c990fd547cef4b657253c545715 --- /dev/null +++ b/tensorflow/compiler/plugin/README.md @@ -0,0 +1,16 @@ +3rd party XLA devices +--------------------- + +This directory is intended as a place for 3rd party XLA devices which are _not_ +integrated into the public repository. + +By adding entries to the BUILD target in this directory, a third party device +can be included as a dependency of the JIT subsystem. + +For integration into the unit test system, see the files: + +- tensorflow/compiler/tests/plugin.bzl +- tensorflow/compiler/xla/tests/plugin.bzl + + +- diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 5a46eb0bb79397ccb2e155f4a1bfa9bd6013b0d0..0eed475140c72034ad664b3ae03f09944d92473f 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -97,9 +97,13 @@ tf_xla_py_test( size = "small", srcs = ["binary_ops_test.py"], shard_count = 5, + tags = [ + "optonly", # Times out frequently in fastbuild mode. + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", + "//tensorflow/python:bitwise_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", @@ -179,6 +183,7 @@ tf_xla_py_test( "noasan", "nomsan", "notsan", + "optonly", # Times out frequently in fastbuild mode. ], deps = [ ":xla_test", @@ -208,11 +213,6 @@ tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], - # TODO(b/62962492): Test fails with assertion error. - tags = [ - "manual", - "notap", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -509,12 +509,8 @@ tf_xla_py_test( tf_xla_py_test( name = "gather_test", - size = "small", + size = "medium", srcs = ["gather_test.py"], - # Gather needs CustomCall on CPU, which is not available in normal - # (not precompiled) TensorFlow. The flag below excludes the CPU - # backend. - disabled_backends = "cpu", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -576,6 +572,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index f3ea57596e9a488c4930312c1ba347b2be2c7f24..44b32b1668443f65ee0a47766683e2730d64b929 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -45,6 +46,10 @@ class BinaryOpsTest(XLATestCase): equality_test = self.assertAllClose equality_test(result, expected, rtol=1e-3) + def _testSymmetricBinary(self, op, a, b, expected, equality_test=None): + self._testBinary(op, a, b, expected, equality_test) + self._testBinary(op, b, a, expected, equality_test) + def ListsAreClose(self, result, expected, rtol): """Tests closeness of two lists of floats.""" self.assertEqual(len(result), len(expected)) @@ -193,6 +198,16 @@ class BinaryOpsTest(XLATestCase): np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_and, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b0, 0b101, 0b1000], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_or, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) def testNumericOps(self): for dtype in self.numeric_types: @@ -790,28 +805,30 @@ class BinaryOpsTest(XLATestCase): def testSplit(self): for dtype in self.numeric_types: - self._testBinary( - lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), - np.int32(0), - np.array([[[1], [2]], [[3], [4]], [[5], [6]]], - dtype=dtype), - expected=[ - np.array([[[1], [2]]], dtype=dtype), - np.array([[[3], [4]]], dtype=dtype), - np.array([[[5], [6]]], dtype=dtype), - ], - equality_test=self.ListsAreClose) - - self._testBinary( - lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), - np.int32(1), - np.array([[[1], [2]], [[3], [4]], [[5], [6]]], - dtype=dtype), - expected=[ - np.array([[[1]], [[3]], [[5]]], dtype=dtype), - np.array([[[2]], [[4]], [[6]]], dtype=dtype), - ], - equality_test=self.ListsAreClose) + for axis in [0, -3]: + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), + np.int32(axis), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1], [2]]], dtype=dtype), + np.array([[[3], [4]]], dtype=dtype), + np.array([[[5], [6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + + for axis in [1, -2]: + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), + np.int32(axis), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1]], [[3]], [[5]]], dtype=dtype), + np.array([[[2]], [[4]], [[6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) def testTile(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index d2a4e4bbd49cd1a78d80163bdbf147c34a455e38..4b81c1d7abcb89ef8f776137d1c7d57481c82515 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -24,8 +24,12 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import flags from tensorflow.python.platform import test +FLAGS = flags.FLAGS + _TEST_TYPES = [dtypes.float32] @@ -81,8 +85,31 @@ class GatherTest(xla_test.XLATestCase): expected = np.take(params_np, [0, 1, 0, 2], axis=axis) self.assertAllEqual(expected, gather_val) + def testSimpleTwoD32_Int64Indices(self): + if np.int64 not in self.int_types: + return + + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], + [12, 13, 14]]) + # The indices must be in bounds for any axis. + indices_np = np.array([0, 1, 0, 2]) + for dtype in _TEST_TYPES: + for axis in 0, 1, -1: + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices = array_ops.placeholder(dtype=dtypes.int64) + gather_t = array_ops.gather(params, indices, axis=axis) + gather_val = session.run( + gather_t, feed_dict={ + params: params_np, + indices: indices_np + }) + expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + self.assertAllEqual(expected, gather_val) + def testHigherRank(self): - # Check that scalar and empty indices shapes work as well. + """Check that scalar and empty indices shapes work as well.""" shape = (2, 1, 3, 2) for indices_shape in (), (0,), (2, 0), (2, 3): for dtype in _TEST_TYPES: @@ -98,5 +125,67 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(gather_np, gather_value) -if __name__ == "__main__": +class GatherBenchmark(test.Benchmark): + """Microbenchmarks for the gather op.""" + + def _benchmarkGather(self, name, axis, gather_indices, use_xla_jit): + + def BuilderFn(): + inputs = variables.Variable( + array_ops.zeros([100, 100, 10, 100, 50], dtype=dtypes.float32), + dtype=dtypes.float32, + name='input') + indices = variables.Variable( + gather_indices, dtype=dtypes.int32, name='indices') + gather_t = array_ops.gather(inputs, indices, axis=axis) + return '%s.axis%d' % (name, axis), [gather_t] + + xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu') + + def _benchmarkSliceGather(self, axis, use_xla_jit): + """Benchmarks a gather op that's really a dynamic slice.""" + self._benchmarkGather('slice_gather', axis, [1], use_xla_jit) + + def _benchmarkNontrivialGather(self, axis, use_xla_jit): + self._benchmarkGather('nontrivial_gather', axis, [9, 1, 0, 2] * 4, + use_xla_jit) + + def benchmarkSliceGatherAxis0(self): + self._benchmarkSliceGather(axis=0, use_xla_jit=False) + + def benchmarkSliceGatherAxis0XLA(self): + self._benchmarkSliceGather(axis=0, use_xla_jit=True) + + def benchmarkSliceGatherAxis1(self): + self._benchmarkSliceGather(axis=1, use_xla_jit=False) + + def benchmarkSliceGatherAxis1XLA(self): + self._benchmarkSliceGather(axis=1, use_xla_jit=True) + + def benchmarkSliceGatherAxis4(self): + self._benchmarkSliceGather(axis=4, use_xla_jit=False) + + def benchmarkSliceGatherAxis4XLA(self): + self._benchmarkSliceGather(axis=4, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis0(self): + self._benchmarkNontrivialGather(axis=0, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis0XLA(self): + self._benchmarkNontrivialGather(axis=0, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis1(self): + self._benchmarkNontrivialGather(axis=1, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis1XLA(self): + self._benchmarkNontrivialGather(axis=1, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis4(self): + self._benchmarkNontrivialGather(axis=4, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis4XLA(self): + self._benchmarkNontrivialGather(axis=4, use_xla_jit=True) + + +if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 11914080eccbf3506e6a17e243bf9f8ba1cbb812..2d8236e2cbdfafb35626cd582ee39b1f917aec7f 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -21,15 +21,12 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.compiler import jit -from tensorflow.core.framework import function_pb2 -from tensorflow.core.framework import node_def_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -118,31 +115,13 @@ class JitLaunchTest(test.TestCase): def testNoOutputs(self): with session_lib.Session() as sess: - # Build a function with a single Const node, whose output is ignored. - fdef = function_pb2.FunctionDef() - fdef.signature.name = "KernelWithNoOutputs" - node = node_def_pb2.NodeDef() - node.op = "Const" - node.name = "ignored" - node.attr["dtype"].type = dtypes.int32.as_datatype_enum - tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[]) - node.attr["value"].tensor.CopyFrom(tensor) - fdef.node_def.extend([node]) # Check that calling the result as a compiled kernel doesn't crash. @function.Defun(compiled=True) def KernelWithNoOutputs(): - return constant_op.constant(100) - - # Hack to override the definition. By accessing .definition, we - # force the _DefinedFunction initialized internally. Then, we - # replace it's internal FunctionDef proto. We do this hack here - # because one typically can't construct KernelWithNoOutputs - # function via Defun decorator directly. - _ = KernelWithNoOutputs.definition - foo = KernelWithNoOutputs - foo._definition = fdef - call = KernelWithNoOutputs() + a = constant_op.constant(100) # pylint: disable=unused-variable + + call = KernelWithNoOutputs() # pylint: disable=assignment-from-no-return sess.run(call, {}) def testAliasing(self): diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index 2660e1d5728caf88e2b9ae73b3e3fde2aee71ed8..ae60d78f1a8dd898c5428a82be2196b52d4638d8 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import unittest + import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase @@ -29,7 +31,7 @@ from tensorflow.python.platform import googletest class NAryOpsTest(XLATestCase): - def _testNAry(self, op, args, expected): + def _testNAry(self, op, args, expected, equality_fn=None): with self.test_session() as session: with self.test_scope(): placeholders = [ @@ -39,7 +41,17 @@ class NAryOpsTest(XLATestCase): feeds = {placeholders[i]: args[i] for i in range(0, len(args))} output = op(placeholders) result = session.run(output, feeds) - self.assertAllClose(result, expected, rtol=1e-3) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def _nAryListCheck(self, results, expected, **kwargs): + self.assertEqual(len(results), len(expected)) + for (r, e) in zip(results, expected): + self.assertAllClose(r, e, **kwargs) + + def _testNAryLists(self, op, args, expected): + self._testNAry(op, args, expected, equality_fn=self._nAryListCheck) def testFloat(self): self._testNAry(math_ops.add_n, @@ -56,6 +68,24 @@ class NAryOpsTest(XLATestCase): np.array([42], dtype=np.float32)], expected=np.array([48], dtype=np.float32)) + @unittest.skip("IdentityN is temporarily CompilationOnly as workaround") + def testIdentityN(self): + self._testNAryLists(array_ops.identity_n, + [np.array([[1, 2, 3]], dtype=np.float32)], + expected=[np.array([[1, 2, 3]], dtype=np.float32)]) + self._testNAryLists(array_ops.identity_n, + [np.array([[1, 2], [3, 4]], dtype=np.float32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)], + expected=[ + np.array([[1, 2], [3, 4]], dtype=np.float32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)]) + self._testNAryLists(array_ops.identity_n, + [np.array([[1], [2], [3], [4]], dtype=np.int32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)], + expected=[ + np.array([[1], [2], [3], [4]], dtype=np.int32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)]) + def testConcat(self): self._testNAry( lambda x: array_ops.concat(x, 0), [ diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index b3ec9424c75ed155b4e0307f77eab2ba23618134..5129171cd42b09f31bb1a4da02ffc6be6093f6f1 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -899,7 +899,7 @@ TEST_F(OpTest, ApproximateEqual) { TEST_F(OpTest, ArgMax) { Repeatedly([this]() { - std::vector dims = RandomDims(1, 5); + std::vector dims = RandomDims(1, 5, 1); int num_dims = dims.size(); int reduce_dim = std::uniform_int_distribution(-num_dims, num_dims)(generator()); @@ -1168,6 +1168,28 @@ TEST_F(OpTest, BiasAddV1) { }); } +TEST_F(OpTest, BitwiseAnd) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseAnd") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, BitwiseOr) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseOr") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + TEST_F(OpTest, BroadcastArgs) { Repeatedly([this]() { // TODO(phawkins): only int32 seems to be implemented in Tensorflow. @@ -1729,6 +1751,14 @@ TEST_F(OpTest, GreaterEqual) { }); } +TEST_F(OpTest, Invert) { + Repeatedly([this]() { + DataType type = DT_INT32; + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Invert").RandomInput(type).Attr("T", type)); + }); +} + TEST_F(OpTest, L2Loss) { Repeatedly([this]() { DataType type = DT_FLOAT; @@ -2653,7 +2683,8 @@ TEST_F(OpTest, Split) { std::vector dims = RandomDims(1); std::uniform_int_distribution ud; int32 dim = std::uniform_int_distribution( - 0, static_cast(dims.size()) - 1)(generator()); + -static_cast(dims.size()), + static_cast(dims.size()) - 1)(generator()); int n = std::uniform_int_distribution(1, 5)(generator()); // Ensure 'dim' is evenly divisible by 'n'. dims[dim] /= n; diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index e0a7bf3e2c8a2836d8d11bfc4bda0ec4c77daefe..71221b284d5b7ff4e3c259cafef9166dc2ef246c 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -26,6 +26,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -309,11 +310,6 @@ class UnaryOpsTest(XLATestCase): [0.032058604, 0.087144323, 0.23688284, 0.64391428]], dtype=dtype)) - self._assertOpOutputMatchesExpected( - nn_ops.softplus, - np.array([[-2, 0, 8]], dtype=dtype), - expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype)) - self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), @@ -332,6 +328,13 @@ class UnaryOpsTest(XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype)) + def testIntOps(self): + for dtype in self.int_types: + self._assertOpOutputMatchesExpected( + bitwise_ops.invert, + np.array([0, -1, 1, 16, 42], dtype=dtype), + expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) + def testNumericOps(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -543,6 +546,26 @@ class UnaryOpsTest(XLATestCase): [[9, 10, 11, 12], [13, 14, 15, 16]]]], dtype=dtype)) + def _assertSoftplusMatchesExpected(self, features, dtype): + features = np.array(features, dtype=dtype) + zero = np.asarray(0).astype(dtype) + expected = np.logaddexp(zero, features) + self._assertOpOutputMatchesExpected( + nn_ops.softplus, features, expected=expected) + + def testSoftplus(self): + for dtype in self.float_types: + self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) + self._assertSoftplusMatchesExpected( + [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype) + log_eps = np.log(np.finfo(dtype).eps) + one = dtype(1) + ten = dtype(10) + self._assertSoftplusMatchesExpected([ + log_eps, log_eps - one, log_eps + one, log_eps - ten, + log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, + -log_eps - ten, -log_eps + ten], dtype) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0769b13718821f7528b7e2620e50787e55aa20f6..3c94bcafc1d19b1bc54887e6f2c25b1886be646e 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -58,6 +58,42 @@ cc_library( ], ) +cc_library( + name = "xla_compiled_cpu_function", + srcs = ["xla_compiled_cpu_function.cc"], + hdrs = ["xla_compiled_cpu_function.h"], + visibility = ["//visibility:public"], + deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + ], +) + +cc_library( + name = "xla_jit_compiled_cpu_function", + srcs = ["xla_jit_compiled_cpu_function.cc"], + hdrs = ["xla_jit_compiled_cpu_function.h"], + visibility = ["//visibility:public"], + deps = [ + ":tf2xla", + ":tf2xla_proto", + ":xla_compiled_cpu_function", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service/cpu:cpu_executable", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "xla_compiler", srcs = [ @@ -67,11 +103,13 @@ cc_library( "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", + "graph_compiler.cc", "xla_cpu_backend.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]), hdrs = [ + "graph_compiler.h", "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", @@ -82,6 +120,7 @@ cc_library( visibility = [":friends"], deps = [ ":common", + ":const_analysis", ":dump_graph", ":functionalize_control_flow", "//tensorflow/compiler/xla:literal_util", @@ -178,6 +217,25 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_jit_compiled_cpu_function_test", + srcs = ["xla_jit_compiled_cpu_function_test.cc"], + deps = [ + ":tf2xla_proto", + ":xla_jit_compiled_cpu_function", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], @@ -198,6 +256,7 @@ tf_cc_test( "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -292,6 +351,7 @@ cc_library( hdrs = ["functionalize_control_flow.h"], deps = [ "//tensorflow/compiler/jit:graph_to_functiondef", + "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:functional_ops", "//tensorflow/compiler/xla:status_macros", @@ -299,6 +359,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:lib", ], ) @@ -316,6 +377,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:functional_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:ops", diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1c7a2046aa549beb2de58d21f517363d4fe8aea7..35b6960a98cda1bf098f3e01cac3df8173bdc729 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -17,15 +17,19 @@ limitations under the License. #include #include +#include #include #include #include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -70,11 +74,24 @@ struct Frame { std::unordered_set nodes; }; +// Returns a textual representation of the names of the nodes in the input. +template +string NodesToString(const T& nodes) { + return strings::StrCat("{", + str_util::Join(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), + "}"); +} + // Copies a subgraph from `graph` to `output` by performing a reverse DFS // starting at nodes in vector `stack`. // `node_map` is a vector indexed by source node ID to dest nodes. // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` -// before the traversal clients can cut the graph. Returns an error if the +// before the traversal clients can cut the graph. If a frame is provided (frame +// != nullptr), then this functions will return an error if the // traversal leaves 'frame'; the client must add enough nodes to `node_map` to // cut the graph and prevent the traversal from escaping. // @@ -84,25 +101,26 @@ struct Frame { // taking from the Switch node was not necessarily the first output, but _Arg // nodes only have one output. By adding the Switch node to `squash_src_outputs` // we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const Frame& frame, +Status CopySubgraph(const Graph& graph, const Frame* frame, std::vector stack, const std::vector& squash_src_outputs, std::vector* node_map, Graph* output) { + VLOG(3) << "Stack: " << NodesToString(stack); std::vector visited(graph.num_node_ids(), false); while (!stack.empty()) { Node* n = stack.back(); stack.pop_back(); - VLOG(3) << "Copying node " << n->name(); + VLOG(5) << "Copying node " << n->name(); if (visited[n->id()]) continue; visited[n->id()] = true; for (const Edge* e : n->in_edges()) { Node* src = e->src(); - if (frame.nodes.find(src) == frame.nodes.end()) { + if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { // We traversed out of the loop frame, without encountering a cut node. - return errors::Internal("Graph traversal of loop frame ", frame.name, + return errors::Internal("Graph traversal of loop frame ", frame->name, " escaped frame at ", src->name(), " without encountering an argument node."); } @@ -119,27 +137,31 @@ Status CopySubgraph(const Graph& graph, const Frame& frame, return Status::OK(); } -Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) { +xla::StatusOr AddNode(const NodeDef& node_def, Graph* graph) { + Status status; + Node* inserted_node = graph->AddNode(node_def, &status); + if (!status.ok()) { + return status; + } + return inserted_node; +} + +xla::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp); + NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); builder.Attr("index", index); TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); - Status status; - *arg_node = graph->AddNode(arg_def, &status); - return status; + return AddNode(arg_def, graph); } -Status BuildRetvalNode(Graph* graph, DataType type, int index, - Node** retval_node) { +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat("_Retval", index)); + ret_def.set_name(strings::StrCat(kRetValOp, index)); AddNodeAttr("T", type, &ret_def); AddNodeAttr("index", index, &ret_def); - Status status; - *retval_node = graph->AddNode(ret_def, &status); - return status; + return AddNode(ret_def, graph); } // Builds a graph for the loop condition. @@ -157,9 +179,8 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, for (int i = 0; i < frame->args.size(); ++i) { const Arg& arg = frame->args[i]; - Node* arg_node; - TF_RETURN_IF_ERROR( - BuildArgNode(output, arg.enter->input_type(0), i, &arg_node)); + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(output, arg.enter->input_type(0), i)); if (arg.is_loop_invariant) { node_map[arg.enter->id()] = arg_node; } else { @@ -169,16 +190,14 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, // Build a Retval node for the loop condition. The LoopCond nodes are always // boolean because of the type constraints on the LoopCond op. - TF_RETURN_IF_ERROR( - BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()])); + TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], + BuildRetvalNode(output, DT_BOOL, 0)); // Performs a reverse DFS, copying nodes and edges to the output graph. // The _Arg and _Retval nodes were added unconditionally above, so we are // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond}, - squash_src_outputs, &node_map, output)); - - return Status::OK(); + return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, + &node_map, output); } // Builds a graph for the loop body. @@ -202,8 +221,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, DataType dtype = arg.enter->input_type(0); arg_types->push_back(dtype); - Node* arg_node; - TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node)); + + TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); if (dtype == DT_RESOURCE) { // The convention of the XLA bridge is that resource variable arguments @@ -213,8 +232,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, TF_RET_CHECK(arg.is_loop_invariant); node_map[arg.enter->id()] = arg_node; } else { - Node* retval_node; - TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node)); + TF_ASSIGN_OR_RETURN(Node * retval_node, + BuildRetvalNode(output, dtype, i)); if (arg.is_loop_invariant) { // Argument is loop-invariant. Forward it from the Arg to the Retval. @@ -237,7 +256,7 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, // Performs a reverse DFS, copying nodes and edges to the output graph. // The _Arg and _Retval nodes were added unconditionally above, so we are // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations), + TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), squash_src_outputs, &node_map, output)); return Status::OK(); @@ -396,10 +415,6 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, arg.exit = edge->dst(); } } - if (arg.exit == nullptr) { - return errors::InvalidArgument("Missing Exit successor to ", - arg.switch_node->name()); - } } } @@ -450,12 +465,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } builder.Input(inputs); TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); - - Status status; - Node* while_node = graph->AddNode(while_def, &status); - if (!status.ok()) { - return status; - } + TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph)); // Copies edges to the Enter nodes and from the Exit nodes onto the While. for (int i = 0; i < frame->args.size(); ++i) { @@ -469,16 +479,21 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } if (!arg.is_loop_invariant) { - std::vector edges(arg.exit->out_edges().begin(), - arg.exit->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - int src_output = - dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; - graph->AddEdge(while_node, src_output, dst, dst_input); + // Add output edges if the output of the loop is consumed. + if (arg.exit != nullptr) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + if (dst_input == Graph::kControlSlot) { + graph->AddControlEdge(while_node, dst); + } else { + graph->AddEdge(while_node, i, dst, dst_input); + } + } } } } @@ -488,6 +503,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, for (Node* node : frame->nodes) { graph->RemoveNode(node); } + frame->nodes.clear(); frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " @@ -496,13 +512,863 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, return Status::OK(); } +class FunctionalizeCond { + public: + // Identifies the connected parts of the tf.Cond. + struct ClusterHandle { + explicit ClusterHandle(int representative = -1) + : representative(representative) {} + + bool operator==(const ClusterHandle& other) const { + return representative == other.representative; + } + + bool operator!=(const ClusterHandle& other) const { + return !(*this == other); + } + + bool operator<(const ClusterHandle& other) const { + return representative < other.representative; + } + + bool operator>(const ClusterHandle& other) const { + return representative > other.representative; + } + + string ToString() const { + return strings::StrCat("Cluster_", representative); + } + + // Vector of UnionFind indexable by ClusterHandle and Node*. + struct Vector { + explicit Vector(size_t size) : clusters(size) {} + + UnionFind& at(const ClusterHandle& cluster) { + return clusters.at(cluster.representative); + } + + UnionFind& at(const Node* node) { + return clusters.at(node->id()); + } + + UnionFind& operator[](const Node* node) { + return clusters.at(node->id()); + } + + size_t size() const { return clusters.size(); } + + void resize(size_t count) { return clusters.resize(count); } + + private: + std::vector> clusters; + }; + + private: + int representative; + }; + + // Represents a node in the clustered graph consisting of switch_nodes, + // merge_nodes as well as the edges into and out of this node to other + // Clusters. Each Cluster corresponds to a ClusterHandle and has a + // corresponding representative. + struct Cluster { + std::unordered_set switch_nodes; + std::unordered_set merge_nodes; + std::unordered_set in_nodes; + std::unordered_set out_nodes; + + // A member of the ClusterHandle corresponding to this Cluster. + ClusterHandle representative; + bool visited = false; + }; + + // Represent the clustered graph as map from cluster representative to + // Cluster. + using ClusteredGraph = std::map; + + // The arguments and condition of a XlaIf. The arguments are ordered by node + // id in the original graph. + struct CondArgs { + struct CondCmp { + bool operator()(const Node* lhs, const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); + } + }; + Node* conditional = nullptr; + std::set args; + }; + + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + + private: + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) + : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} + + // Returns a vector of Merge nodes from the clustered graph where the nodes + // are sorted by the number of switch nodes minus number of merge nodes + // from a root of the clustered graph to the given Merge node, with ties + // broken by the representative of the Cluster. + std::vector> SortedMergeNodes(); + + // Returns whether the graph has no conditionals. + bool NoConditionals() const { return merge_nodes_.empty(); } + + // Construct the clustered graph by creating nodes for each cluster and the + // connections between the clusters. Switch and Merge nodes partition + // clusters, so iterate over those. Note: a Cluster may have neither a + // Merge or Switch but will have an in/out edge from a Cluster that has. + void CreateClusters(); + + // Creates the clustered graph by identifying all the edges between different + // clusters and collecting all switch and merge nodes that correspond to a + // cluster. + void CreateClusteredGraph(); + + // If `from` and `to` correspond to different clusters, then merge the nodes + // in the clustered graph corresponding to `from` and `to`. + // + // If `remove_from_graph` is specified then the `from` node is also removed + // from the clustered graph post contracting the edge. + void ContractEdge(Cluster* from, Cluster* to, bool remove_from_graph = false); + + // Converts a Merge node to a XlaIf. This encapsulates the process of + // extracting the bodies needed for the then and else branch, creates a XlaIf + // node, removing the nodes of the branches from the graph and replacing the + // merge node with a XlaIf. + Status ConvertMergeToXlaIf(Cluster* merge_cluster); + + // Removes a Switch cluster feeding directly into a Merge cluster by removing + // the Switch and Merge nodes and collapsing into a single cluster. + Status RemoveTrivialMerge(Cluster* merge_cluster); + + // Returns the switch cluster corresponding to the merge node. This function + // only returns the switch cluster in the simple case where we have a switch + // node is the entry of a diamond corresponding to a conditional: + // + // Switch + // / \ + // Branch Branch + // \ / + // merge_cluster + // + // Note: either of the branches may be empty. The case where both branches are + // empty is handled by RemoveTrivialMerge. + gtl::optional GetSwitchCluster(const Cluster& merge_cluster); + + // Determines the arguments needed as input to the Merge cluster originating + // from the Switch cluster. + xla::StatusOr DetermineCondArgs(const Cluster& merge_cluster, + const Cluster& switch_cluster); + + // Builds a XlaIfOp to replace the Merge node with. + xla::StatusOr BuildAndAddXlaIfOp(const CondArgs& cond_args, + const Cluster& merge_cluster, + const std::vector& outputs); + + // Extracts a function body corresponding to the given input edge of the merge + // node. + Status ExtractBody(const CondArgs& cond_args, const Cluster& merge_cluster, + const std::vector& outputs, int input_edge, + Graph* body); + + // Adds all the input edges to `if_node` corresponding to the arguments. + Status AddInputEdges(const CondArgs& cond_args, Node* if_node); + + // Adds all output edges from the `if_node`. + Status AddOutputEdges(const std::vector& outputs, Node* if_node); + + // Removes all nodes from the graph that are part of cluster. + void RemoveClusterNodes(Cluster* cluster); + + // Removes all argument nodes that are unused. + template + void RemoveUnusedArgs(const T& args); + + // Removes all Merge nodes in merge_cluster. + void RemoveMergeNodes(Cluster* merge_cluster); + + // Returns the representative member of the corresponding cluster. + ClusterHandle Representative(const Node* node) { + return clusters_.at(node).Get(); + } + + ClusteredGraph clustered_graph_; + ClusterHandle::Vector clusters_; + std::unordered_set merge_nodes_; + std::unordered_set switch_nodes_; + FunctionLibraryDefinition* library_; + Graph* graph_; +}; + +std::ostream& operator<<(std::ostream& os, + const FunctionalizeCond::ClusterHandle& c) { + os << c.ToString(); + return os; +} + +// Returns a dot representation of the clustered graph showing the connections +// between the nodes and the nodes in each cluster. +string DebugString(const Graph& graph, + FunctionalizeCond::ClusterHandle::Vector* clusters) { + string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; + std::map subgraphs; + for (Node* n : graph.nodes()) { + if (n->IsOp()) { + strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), + " [label=\"", n->name(), "\"];\n"); + } + } + for (auto kv : subgraphs) { + strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", + "style=filled; color=lightgrey;", "label = \"", + kv.first.ToString(), "\";\n", kv.second, "}\n"); + } + for (Node* n : graph.nodes()) { + if (!n->IsOp()) { + continue; + } + for (Node* in : n->in_nodes()) { + if (in->IsOp()) { + strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); + } + } + } + return strings::StrCat(ret, "}"); +} + +string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { + string ret = "digraph {\ncompound=true;labeljust=\"r\";\n"; + auto name = [](const FunctionalizeCond::Cluster& cluster) { + return cluster.representative.ToString(); + }; + for (auto kv : clustered_graph) { + strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), + " (", kv.second.switch_nodes.size(), ", ", + kv.second.merge_nodes.size(), ")\"];\n"); + } + for (auto kv : clustered_graph) { + for (auto in : kv.second.in_nodes) { + strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); + } + } + return strings::StrCat(ret, "}"); +} + +bool IsDeadSwitch(const Node* node) { + for (const Edge* e : node->out_edges()) { + const Node* dst = e->dst(); + if (!dst->IsIdentity()) { + return false; + } + for (const Edge* ee : dst->out_edges()) { + if (!ee->IsControlEdge() || !ee->dst()->IsSink()) { + return false; + } + } + } + return true; +} + +void FunctionalizeCond::CreateClusters() { + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + if (IsSwitch(node)) { + switch_nodes_.insert(node); + } else if (IsMerge(node)) { + merge_nodes_.insert(node); + } + ClusterHandle& cluster = clusters_.at(node).Get(); + cluster = ClusterHandle(node->id()); + } + + // If there are no Merge nodes, then terminate. + if (merge_nodes_.empty()) { + return; + } + + // Remove all dead Switch nodes. + RemoveUnusedArgs(switch_nodes_); + + // All parent_'s are still nullptr so clusters_ may still be resized. Resize + // conservatively assuming all merge nodes become XlaIf nodes. + clusters_.resize(clusters_.size() + merge_nodes_.size()); + + // Merge a cluster with its input, unless the input is a Switch node or + // the node is a Merge node. + for (const Node* node : graph_->nodes()) { + if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) { + continue; + } + for (const Node* in : node->in_nodes()) { + if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) { + clusters_.at(node).Merge(&clusters_.at(in)); + } + } + } +} + +void FunctionalizeCond::ContractEdge(Cluster* from, Cluster* to, + bool remove_from_graph) { + VLOG(3) << "ContractEdge from = " << from->representative + << " to = " << to->representative; + if (from->representative == to->representative) { + return; + } + to->merge_nodes.insert(from->merge_nodes.begin(), from->merge_nodes.end()); + from->merge_nodes.clear(); + to->switch_nodes.insert(from->switch_nodes.begin(), from->switch_nodes.end()); + from->switch_nodes.clear(); + + for (Cluster* from_out : from->out_nodes) { + from_out->in_nodes.erase(from); + if (from_out->representative != to->representative) { + from_out->in_nodes.insert(to); + to->out_nodes.insert(from_out); + } + } + from->out_nodes.clear(); + + for (Cluster* from_in : from->in_nodes) { + from_in->out_nodes.erase(from); + if (from_in->representative != to->representative) { + from_in->out_nodes.insert(to); + to->in_nodes.insert(from_in); + } + } + from->in_nodes.clear(); + + to->in_nodes.erase(from); + to->out_nodes.erase(from); + clusters_.at(to->representative).Merge(&clusters_.at(from->representative)); + from->visited = true; + + if (remove_from_graph) { + clustered_graph_.erase(from->representative); + } +} + +void FunctionalizeCond::CreateClusteredGraph() { + auto update_cluster_for_node = [this](Node* node) -> Cluster& { + ClusterHandle repr = Representative(node); + Cluster& cluster_node = clustered_graph_[repr]; + cluster_node.representative = repr; + for (const Node* in : node->in_nodes()) { + ClusterHandle other_repr = Representative(in); + // Skip source, sink and internal edges. + if (!in->IsOp() || other_repr == repr) { + continue; + } + Cluster& cluster_node_in = clustered_graph_[other_repr]; + cluster_node.in_nodes.insert(&cluster_node_in); + cluster_node_in.out_nodes.insert(&cluster_node); + cluster_node_in.representative = other_repr; + } + for (const Node* out : node->out_nodes()) { + ClusterHandle other_repr = Representative(out); + // Skip source, sink and internal edges. + if (!out->IsOp() || other_repr == repr) { + continue; + } + Cluster& cluster_node_out = clustered_graph_[other_repr]; + cluster_node.out_nodes.insert(&cluster_node_out); + cluster_node_out.in_nodes.insert(&cluster_node); + cluster_node_out.representative = other_repr; + } + return cluster_node; + }; + for (Node* node : switch_nodes_) { + update_cluster_for_node(node).switch_nodes.insert(node); + } + for (Node* node : merge_nodes_) { + update_cluster_for_node(node).merge_nodes.insert(node); + } + + // Merge Switch nodes with common predicate. + std::unordered_map> predicate_to_switch; + for (Node* node : switch_nodes_) { + Node* tmp; + TF_CHECK_OK(node->input_node(1, &tmp)); + predicate_to_switch[tmp].push_back(node); + } + for (auto kv : predicate_to_switch) { + Cluster& first = clustered_graph_.at(Representative(kv.second.front())); + for (Node* switch_node : kv.second) { + ClusterHandle handle = Representative(switch_node); + Cluster& cluster = clustered_graph_.at(handle); + ContractEdge(&cluster, &first, /*remove_from_graph=*/true); + } + } + + // Merge Merge nodes with common input together. + for (Node* node : merge_nodes_) { + Cluster& cluster = clustered_graph_.at(Representative(node)); + for (const Node* in : node->in_nodes()) { + if (!in->IsOp()) { + continue; + } + Cluster& cluster_node_in = clustered_graph_.at(Representative(in)); + // ContractEdge can modify out_nodes of cluster_node_in, so traverse + // over out_nodes assuming it does. + for (auto it = cluster_node_in.out_nodes.begin(); + it != cluster_node_in.out_nodes.end();) { + if (!(*it)->merge_nodes.empty()) { + ContractEdge(*it++, &cluster, /*remove_from_graph=*/true); + } else { + ++it; + } + } + } + } + + VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); + VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); +} + +gtl::optional FunctionalizeCond::GetSwitchCluster( + const Cluster& merge_cluster) { + VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative; + gtl::optional switch_cluster; + if (merge_cluster.in_nodes.size() > 2) { + return gtl::nullopt; + } + for (Cluster* in : merge_cluster.in_nodes) { + Cluster* cluster = in; + if (in->switch_nodes.empty()) { + if (in->in_nodes.size() != 1) { + return gtl::nullopt; + } + // There is only a single `in` cluster. + cluster = *in->in_nodes.begin(); + } + if (cluster->switch_nodes.empty()) { + return gtl::nullopt; + } + + if (switch_cluster.has_value() && *switch_cluster != cluster) { + return gtl::nullopt; + } else { + switch_cluster = cluster; + } + } + return switch_cluster; +} + +xla::StatusOr FunctionalizeCond::DetermineCondArgs( + const Cluster& merge_cluster, const Cluster& switch_cluster) { + VLOG(2) << "DetermineCondArgs for " << merge_cluster.representative + << " with switch cluster " << switch_cluster.representative; + CondArgs ret; + auto feeds_into_branch_cluster = [&](Node* switch_cluster) { + for (Node* out : switch_cluster->out_nodes()) { + ClusterHandle repr = Representative(out); + if (repr == merge_cluster.representative) { + return true; + } + for (Cluster* in : merge_cluster.in_nodes) { + if (repr == in->representative) { + return true; + } + } + } + return false; + }; + for (Node* switch_cluster_node : switch_cluster.switch_nodes) { + if (!feeds_into_branch_cluster(switch_cluster_node)) { + continue; + } + + Node* tmp; + TF_RETURN_IF_ERROR(switch_cluster_node->input_node(1, &tmp)); + if (ret.conditional == nullptr) { + ret.conditional = tmp; + } else if (ret.conditional != tmp) { + return errors::Unimplemented( + "Switch statements with different conditionals cannot be " + "converted into functional conditional."); + } + ret.args.insert(switch_cluster_node); + } + return ret; +} + +xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( + const CondArgs& cond_args, const Cluster& merge_cluster, + const std::vector& outputs) { + VLOG(2) << "Build if op for " << NodesToString(merge_cluster.merge_nodes) + << " with input " << NodesToString(cond_args.args); + + NodeDef if_def; + // Create a new If node using the name of the merge node. + NodeDefBuilder builder( + strings::StrCat((*merge_cluster.merge_nodes.begin())->name(), "_If"), + "XlaIf"); + string branch[] = {"else_branch", "then_branch"}; + for (int i = 0; i < 2; ++i) { + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + + NameAttrList body_name; + body_name.set_name( + strings::StrCat("_functionalize_if_", branch[i], "_", id)); + auto body = xla::MakeUnique(graph_->op_registry()); + TF_RETURN_IF_ERROR( + ExtractBody(cond_args, merge_cluster, outputs, i, body.get())); + VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); + TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); + builder.Attr(branch[i], body_name); + } + + // Build input type. + std::vector inputs; + DataTypeVector in_arg_types; + for (const Node* arg : cond_args.args) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + } + } + builder.Attr("Tin", in_arg_types); + + // Build output type. + DataTypeVector out_type; + for (const Node* merge : merge_cluster.merge_nodes) { + DataType dtype = merge->output_type(0); + out_type.push_back(dtype); + } + builder.Attr("Tout", out_type); + + builder.Attr("Tcond", DT_BOOL); + builder.Device(cond_args.conditional->assigned_device_name()); + // Conditional should be the first input ... + builder.Input(NodeDefBuilder::NodeOut(cond_args.conditional->name(), 0, + cond_args.conditional->output_type(0))); + // ... followed by the other inputs. + builder.Input(inputs); + + TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); + TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_)); + return if_node; +} + +void FunctionalizeCond::RemoveClusterNodes(Cluster* cluster) { + VLOG(3) << "RemoveClusterNodes for " << cluster->representative; + ClusterHandle repr = cluster->representative; + std::deque to_delete; + for (Node* node : graph_->nodes()) { + if (Representative(node) == repr) { + to_delete.push_back(node); + } + } + for (Node* n : to_delete) { + graph_->RemoveNode(n); + } +} + +template +void FunctionalizeCond::RemoveUnusedArgs(const T& args) { + VLOG(2) << "RemoveUnusedArgs among: " << NodesToString(args); + + std::deque to_delete; + for (Node* arg : args) { + if (IsDeadSwitch(arg)) { + to_delete.push_back(arg); + for (Node* n : arg->out_nodes()) { + to_delete.push_back(n); + } + } + } + for (Node* n : to_delete) { + switch_nodes_.erase(n); + auto it = clustered_graph_.find(Representative(n)); + if (it != clustered_graph_.end()) { + it->second.switch_nodes.erase(n); + } + graph_->RemoveNode(n); + } +} + +Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, + const Cluster& merge_cluster, + const std::vector& outputs, + int input_edge, Graph* body) { + VLOG(2) << "ExtractBody for " << merge_cluster.representative + << " along edge " << input_edge; + std::vector squash_src_outputs(graph_->num_node_ids(), false); + std::vector node_map(graph_->num_node_ids(), nullptr); + int arg_count = 0; + for (const auto* arg : cond_args.args) { + DataType dtype = arg->input_type(0); + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(body, dtype, arg_count++)); + node_map.at(arg->id()) = arg_node; + squash_src_outputs.at(arg->id()) = true; + } + + std::vector stack; + stack.reserve(outputs.size()); + for (int j = 0; j < outputs.size(); ++j) { + Node* node = outputs[j]; + TF_ASSIGN_OR_RETURN(node_map.at(node->id()), + BuildRetvalNode(body, node->output_type(0), + /*index=*/j)); + const Edge* in_edge; + TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge)); + Node* in = in_edge->src(); + if (node_map.at(in->id()) == nullptr) { + node_map.at(in->id()) = body->CopyNode(in); + } + + if (cond_args.args.find(in) == cond_args.args.end()) { + body->AddEdge(node_map.at(in->id()), in_edge->src_output(), + node_map.at(node->id()), 0); + } else { + body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0); + // Don't include input nodes that are already just returned in stack. + continue; + } + stack.push_back(in); + } + + return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map, + body); +} + +Status FunctionalizeCond::AddInputEdges(const CondArgs& cond_args, + Node* if_node) { + VLOG(3) << "AddInputEdges for " << if_node->name(); + int i = 0; + graph_->AddEdge(cond_args.conditional, 0, if_node, i++); + for (const Node* arg : cond_args.args) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph_->AddControlEdge(in_edge->src(), if_node); + } else { + graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, i++); + } + } + return Status::OK(); +} + +Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, + Node* if_node) { + VLOG(3) << "AddOutputEdges for " << if_node->name(); + for (int i = 0; i < outputs.size(); ++i) { + Node* node = outputs[i]; + std::vector edges(node->out_edges().begin(), + node->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + + if (edge->src_output() > 0) { + return errors::Unimplemented("Output of index (", edge->src_output(), + ") of merge node ", node->name()); + } + graph_->RemoveEdge(edge); + + int src_output = + dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph_->AddEdge(if_node, src_output, dst, dst_input); + } + } + return Status::OK(); +} + +void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { + VLOG(3) << "RemoveMergeNodes for " << merge_cluster->representative; + // Remove all merge nodes now dead post extraction of If. + for (auto it = merge_cluster->merge_nodes.begin(); + it != merge_cluster->merge_nodes.end();) { + Node* node = *it; + graph_->RemoveNode(node); + merge_cluster->merge_nodes.erase(*it++); + } +} + +Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) { + Cluster* switch_cluster = *merge_cluster->in_nodes.begin(); + if (switch_cluster->switch_nodes.empty()) { + return errors::FailedPrecondition( + "Not a trivial merge: no Switch node feeding into Merge node"); + } + + for (auto it = merge_cluster->merge_nodes.begin(); + it != merge_cluster->merge_nodes.end();) { + // We have the following structure: + // Op -> Switch -> Merge -> Consumer + // and we want to transform it to: + // Op -> Consumer + Node* merge_node = *it; + Node* switch_node; + const Edge* in = nullptr; + TF_RETURN_IF_ERROR(merge_node->input_node(0, &switch_node)); + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in)); + for (auto out : merge_node->out_edges()) { + int src_output = out->dst_input() == Graph::kControlSlot + ? Graph::kControlSlot + : in->src_output(); + graph_->AddEdge(in->src(), src_output, out->dst(), out->dst_input()); + } + graph_->RemoveNode(*it++); + } + RemoveUnusedArgs(switch_cluster->switch_nodes); + + return Status::OK(); +} + +Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { + VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative; + gtl::optional switch_cluster = GetSwitchCluster(*merge_cluster); + if (!switch_cluster.has_value()) { + return errors::FailedPrecondition( + "Merge cluster was not part of a simple conditional in the clustered " + "graph. Graph nodes in merge cluster ", + NodesToString(merge_cluster->merge_nodes)); + } + TF_ASSIGN_OR_RETURN(auto cond_args, + DetermineCondArgs(*merge_cluster, **switch_cluster)); + + // Sort the outputs by ID to produce more stable output. + std::vector outputs(merge_cluster->merge_nodes.begin(), + merge_cluster->merge_nodes.end()); + std::sort(outputs.begin(), outputs.end(), CondArgs::CondCmp()); + + // Extract bodies and builds a If operator. + TF_ASSIGN_OR_RETURN(Node * if_node, + BuildAndAddXlaIfOp(cond_args, *merge_cluster, outputs)); + TF_RETURN_IF_ERROR(AddInputEdges(cond_args, if_node)); + TF_RETURN_IF_ERROR(AddOutputEdges(outputs, if_node)); + + // Remove the old nodes from the graph_ and contract the edges of the + // clustered graph. + for (auto in : merge_cluster->in_nodes) { + if (in != *switch_cluster) { + RemoveClusterNodes(in); + } + } + RemoveMergeNodes(merge_cluster); + RemoveUnusedArgs(cond_args.args); + auto in_nodes = merge_cluster->in_nodes; + for (auto it = in_nodes.begin(); it != in_nodes.end();) { + ContractEdge(*it++, merge_cluster); + } + ContractEdge(*switch_cluster, merge_cluster); + clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative); + + return Status::OK(); +} + +std::vector> +FunctionalizeCond::SortedMergeNodes() { + VLOG(2) << "ProcessClusteredGraph"; + std::stack> stack; + for (auto& c : clustered_graph_) { + if (c.second.in_nodes.empty()) { + stack.push({0, &c.second}); + } + } + + // Perform a depth-first traversal of the clustered graph computing the + // switch-merge depth. + std::vector> queue; + std::unordered_set visited; + while (!stack.empty()) { + Cluster* n = stack.top().second; + size_t depth = stack.top().first; + stack.pop(); + + auto inserted = visited.insert(n); + if (!inserted.second) { + continue; + } + + size_t new_depth = depth; + if (!n->merge_nodes.empty()) { + queue.emplace_back(depth, n); + --new_depth; + } + if (!n->switch_nodes.empty()) { + ++new_depth; + } + for (Cluster* e : n->out_nodes) { + stack.emplace(new_depth, e); + } + } + + // Sort in reverse order of switch-merge depth with ties broken by the + // ClusterHandle. + std::sort(queue.begin(), queue.end(), + [](const std::pair& lhs, + const std::pair& rhs) { + return std::tie(lhs.first, lhs.second->representative) > + std::tie(rhs.first, rhs.second->representative); + }); + + return queue; +} + +Status FunctionalizeCond::Functionalize(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "FunctionalizeCond::Functionalize"; + FunctionalizeCond fc(graph, library); + fc.CreateClusters(); + if (fc.NoConditionals()) { + return Status::OK(); + } + fc.CreateClusteredGraph(); + + auto queue = fc.SortedMergeNodes(); + for (auto it = queue.begin(); it != queue.end();) { + Cluster* merge_cluster = (*it).second; + ++it; + if (merge_cluster->in_nodes.size() == 1) { + TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster)); + } else { + TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster)); + } + + // Contract newly Merge free merge_cluster with incoming nodes without + // Switch or Merge nodes. + std::vector in_nodes(merge_cluster->in_nodes.begin(), + merge_cluster->in_nodes.end()); + for (auto in : in_nodes) { + if (in->merge_nodes.empty() && in->switch_nodes.empty()) { + fc.ContractEdge(in, merge_cluster); + } + } + } + + if (!fc.switch_nodes_.empty()) { + return errors::Internal( + "Failed to functionalize control flow with Switch nodes remaining: ", + NodesToString(fc.switch_nodes_)); + } + return Status::OK(); +} + } // namespace // Transformation that converts Tensorflow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { - VLOG(2) << "FunctionalizeControlFlow: " + VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph); // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this @@ -577,6 +1443,13 @@ Status FunctionalizeControlFlow(Graph* graph, } } + // FunctionalizeControlFlow is invoked for every function, so the loops's + // bodies and conditionals that were extracted into functions will be handled + // in successive invocations. + TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); + + VLOG(2) << "FunctionalizeControlFlow (final): " + << dump_graph::DumpGraphToFile("functionalize_final", *graph); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 1535dc80b0ccdba38c57b534ed7473fc8632e33f..4d4ee3054c2914bb614bf75f7a51be8f6292683e 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -23,7 +23,6 @@ namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While // operators, suitable for XLA compilation. -// TODO(b/36470387): add support for conditionals. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 914c8999a6f13f5f2dc4e3cecc38c91afd432131..01d2b282751f387cfa9c8887cdeb48090c96bff4 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -35,6 +36,134 @@ limitations under the License. namespace tensorflow { namespace { +// Returns the names of the "then" and "else" functions for the XlaIf node in a +// graph. +Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn, + NameAttrList* else_fn) { + for (const NodeDef& node : graph.node()) { + if (node.op() == "XlaIf") { + const NameAttrList* result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); + *then_fn = *result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result)); + *else_fn = *result; + return Status::OK(); + } + } + return errors::NotFound("No XlaIf node found in graph"); +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = array_ops.placeholder(dtypes.int32) +// z = control_flow_ops.cond( +// math_ops.less(y, x), lambda: math_ops.multiply(y, 17), +// lambda: math_ops.add(x, 23)) +TEST(FunctionalizeControlFlow, Conditional) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less); + + auto identity_t = + ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true); + auto seventeen = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity_t), 17); + auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less); + auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true, + seventeen); + + auto identity_f = + ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false); + auto twenty_three = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity_f), 23); + auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); + auto add = ops::Add(scope.WithOpName("cond/false/add"), + switch_3.output_false, twenty_three); + + auto merge = ops::Merge(scope.WithOpName("cond/Merge"), + std::initializer_list{add, mul}); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = ops::XlaIf(scope.WithOpName("cond/Merge_If"), less, + std::initializer_list{less, y, x}, then_fn, + else_fn, {DT_INT32}); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // then body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // else body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Returns the names of the "cond" and "body" functions for the While node // in a graph. Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, @@ -168,6 +297,108 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// Tests functionalizing OneLoopVar where the loop value is not used post the +// loop. +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) +TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Graph: // x = array_ops.placeholder(dtypes.int32) // y = array_ops.placeholder(dtypes.int32) diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..6f2f59d98fb03ffd7db19aaa70774ecfa4b78ce9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/graph_compiler.h" + +#include +#include +#include +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +namespace { +Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, + const std::vector& expressions, + std::vector* args) { + auto builder = ctx->builder(); + std::vector compile_time_constant_flags(expressions.size()); + + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(*graph, &compile_time_constant_flags)); + + args->resize(expressions.size()); + for (int i = 0; i < args->size(); ++i) { + XlaCompiler::Argument& arg = (*args)[i]; + arg.type = ctx->input_type(i); + + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + + if (arg.type == DT_RESOURCE) { + return errors::InvalidArgument( + "Resource as function argument is not yet implemented."); + } else if (expressions[i]->has_constant_value()) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = expressions[i]->constant_value(); + } else if (compile_time_constant_flags[i]) { + arg.kind = XlaCompiler::Argument::kConstant; + TF_RET_CHECK(expressions[i]->resource() == nullptr) + << "Input with resource is not yet implemented."; + TF_ASSIGN_OR_RETURN(auto literal, + builder->ComputeConstant(expressions[i]->handle())); + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + } + return Status::OK(); +} +} // namespace +Status GraphCompiler::Compile() { + std::vector bindings(graph_->num_node_ids()); + std::vector topo_sorted_nodes; + // XLA requires determinism, generate a stable ordering from DFS. + GetReversePostOrder(*graph_, &topo_sorted_nodes, + /*stable_comparator=*/NodeComparatorName()); + + OpKernelContext::Params params; + PartiallySetupParams(¶ms); + + for (Node* n : topo_sorted_nodes) { + // Set up bindings. + NodeBinding& binding = bindings[n->id()]; + binding.node = n; + Status s = flib_->CreateKernel(n->def(), &binding.op_kernel); + binding.output_attrs.resize(n->num_outputs()); + if (!s.ok()) { + binding.op_kernel = nullptr; + s = AttachDef(s, *n); + LOG(ERROR) << "Executor failed to create kernel. " << s; + return s; + } + } + + // Bindings are initialized by the size of graph_->num_node_ids. However, the + // graph may contain dead nodes that still hold a valid node id. Thus + // graph_->num_node_ids could be larger than number of topo sorted nodes. + TF_RET_CHECK(bindings.size() >= topo_sorted_nodes.size()); + + for (Node* n : topo_sorted_nodes) { + TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch()) + << "Not supported node: " << n->DebugString(); + NodeBinding& binding = bindings[n->id()]; + params.op_kernel = binding.op_kernel; + params.output_attr_array = binding.output_attrs.data(); + + // tensor_inputs_ is a buffer reused across graph traversal. We clean up and + // reinitialize the buffer before we visit a new node. + tensor_inputs_.clear(); + tensor_inputs_.resize(n->num_inputs()); + + // Set up inputs from outputs of previous nodes. + for (auto* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + Node* src = e->src(); + tensor_inputs_[e->dst_input()] = + bindings[src->id()].tensor_values[e->src_output()]; + } + + OpKernelContext op_context(¶ms, n->num_outputs()); + if (IsFunctional(n)) { + TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); + } else { + device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); + Status s = op_context.status(); + TF_RETURN_IF_ERROR(s); + } + + // Set up outputs. Also check if outputs from the previous computation is + // valid. + for (int o = 0; o < n->num_outputs(); ++o) { + const auto tensor_val = op_context.release_output(o); + if (*op_context.is_output_dead() || tensor_val.tensor == nullptr) { + return errors::Internal("Missing xla_context ", o, "-th output from ", + (*op_context.is_output_dead() ? "(dead)" : ""), + SummarizeNode(*n)); + } + binding.tensor_values.push_back(tensor_val); + } + } + + // Clean up tensor data and op kernels. + for (NodeBinding& binding : bindings) { + delete binding.op_kernel; + for (auto& t : binding.tensor_values) { + if (!t.is_ref()) { + delete t.tensor; + } + } + } + return Status::OK(); +} + +bool GraphCompiler::IsFunctional(Node* n) { + return n->type_string() == FunctionLibraryDefinition::kGradientOp || + (flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) != + nullptr); +} + +Status GraphCompiler::CompileFunctionalNode(Node* n, + OpKernelContext* op_context) { + TF_RET_CHECK(IsFunctional(n)); + // For functional nodes, compile them using compiler from the context and call + // into the functions. + XlaOpKernelContext xla_op_context(op_context); + + XlaCompiler* compiler = xla_op_context.compiler(); + + NameAttrList func; + if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) { + func.set_name(n->def().op()); + } else { + func.set_name(FunctionLibraryDefinition::kGradientOp); + } + *func.mutable_attr() = n->def().attr(); + + std::vector expressions; + + for (auto tensor : tensor_inputs_) { + auto expression = + reinterpret_cast(tensor->tensor_data().data()); + expressions.push_back(expression); + } + + // Prepare the arguments and compile the function. + std::vector arguments; + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody)); + + auto graph = compiler->GetGraph(fbody); + + TF_RETURN_IF_ERROR( + PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + + XlaCompiler::CompilationResult result; + + TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), + func, arguments, &result)); + + TF_RET_CHECK(arguments.size() == expressions.size()); + + std::vector handles; + for (int64 i = 0; i < expressions.size(); ++i) { + if (arguments[i].kind == XlaCompiler::Argument::kConstant) { + continue; + } + handles.push_back(expressions[i]->handle()); + } + + XlaContext& context = XlaContext::Get(op_context); + auto* b = context.builder(); + + auto output_handle = b->Call(*result.computation, handles); + // The output handle of `Call` computation is a tuple type. Unzip it so + // that it can fit into future computations. + for (int64 i = 0; i < n->num_outputs(); ++i) { + if (result.outputs[i].is_constant) { + xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); + } else { + xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + } + } + return b->first_error(); +} + +void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) { + params->device = device_; + params->inputs = &tensor_inputs_; + params->step_container = step_container_; + params->resource_manager = device_->resource_manager(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ccf9351642fb21ab8f14bedd616fdb92215a6492 --- /dev/null +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +// GraphCompiler compiles the graph in topological order in the current +// thread. It also resolves the nondeterminism in the graph by enforcing a +// total order on all inputs to a node. This abstraction helps us create the +// same XLA computation given two structurally equivalent TensorFlow graphs. +// If a function call is visited during the graph traversal, it is then +// compiled through the xla_context into a computation and a `Call` operation +// is inserted to call into that computation. +// +// Note: GraphCompiler was created to remove our dependency to TF Executor in +// the history. There are still some todos so that we can completely decouple +// from Executor. +// +// TODO(yunxing): Remove usage of XlaCompilationDevice. +// +// TODO(yunxing): Remove the hack that wraps XlaExpression within a tensor now +// that we don't use TF Executor to pass around a tensor. +// +// TODO(yunxing): Make XlaOpkernel not a subclass of OpKernel so that it can +// handle a XlaExpression directly instead of a Tensor. This may require our own +// op registration infrastructure instead of FunctionLibraryRuntime. +class GraphCompiler { + public: + GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, + Graph* graph, FunctionLibraryRuntime* flib, + ScopedStepContainer* step_container) + : xla_context_(xla_context), + device_(device), + graph_(graph), + flib_(flib), + step_container_(step_container) {} + + // Compiles the graph. The results are written in `xla_context` that is passed + // into the compiler. + Status Compile(); + + private: + // NodeBinding is a wrapper on a `Node` that also contains computed + // TensorValue. + struct NodeBinding { + const Node* node; + // Kernel for this node, to be filled by CreateKernel. + // TODO(yunxing): Switching this to unique_ptr and understand why it crashes + // on GPU devices. + OpKernel* op_kernel; + // Output values of this node. + std::vector tensor_values; + // Attributes of the outputs. + gtl::InlinedVector output_attrs; + }; + + // Partially sets params. This partially set params can be reused + // across multple nodes visit. + void PartiallySetupParams(OpKernelContext::Params* params); + + // Tests if a node is a functional node. A functional node represents a + // defined computation and should be compiled using `compiler_`. + bool IsFunctional(Node* n); + + // Compiles a functional node and writes result to OpkernelContext. A + // functional node represents a defined computation and should be compiled + // using `compiler_`. + Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); + + XlaContext* xla_context_; + XlaCompilationDevice* device_; + Graph* graph_; + FunctionLibraryRuntime* flib_; + ScopedStepContainer* step_container_; + // A buffer to hold tensor inputs to a node, this is reused across the graph + // traversal. + gtl::InlinedVector tensor_inputs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6a0c4fef7574319de2837f7a17867a41784d56b4..f44d61de686278a46b6780eaa974a7939d42a481 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -132,8 +132,6 @@ tf_kernel_library( name = "xla_cpu_only_ops", srcs = ["index_ops_cpu.cc"], deps = [ - ":gather_op_kernel_float_int32", - ":gather_op_kernel_float_int64", ":index_ops_kernel_argmax_float_1d", ":index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/tf2xla:common", @@ -149,34 +147,6 @@ tf_kernel_library( ], ) -cc_library( - name = "gather_op_kernel_float_int32", - srcs = ["gather_op_kernel_float_int32.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:bounds_check", - "//tensorflow/core/kernels:gather_functor_hdr", - "//third_party/eigen3", - ], - alwayslink = 1, -) - -cc_library( - name = "gather_op_kernel_float_int64", - srcs = ["gather_op_kernel_float_int64.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:bounds_check", - "//tensorflow/core/kernels:gather_functor_hdr", - "//third_party/eigen3", - ], - alwayslink = 1, -) - cc_library( name = "index_ops_kernel_argmax_float_1d", srcs = ["index_ops_kernel_argmax_float_1d.cc"], diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 58538b45137b26ed5aa296eb6c1077e88aea72b9..d635507989bbf78a073be8a50d943dba8688438e 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -96,8 +96,10 @@ static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(LogicalAnd, b->LogicalAnd(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, b->LogicalOr(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 0091b66d28ad62fcd5c0f3b09e90fed8347bb661..885f716afafca7ba23770e38f6693eed1ba50982 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -179,8 +179,10 @@ class ConvOp : public XlaOpKernel { xla::ConvolutionDimensionNumbers dims; std::vector window_strides; - dims.set_batch_dimension(GetTensorBatchDimIndex(num_dims(), data_format_)); - dims.set_feature_dimension(feature_dim); + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); for (int i = 0; i < num_spatial_dims_; ++i) { int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dims.add_spatial_dimensions(input_dim); @@ -285,8 +287,10 @@ class ConvBackpropInputOp : public XlaOpKernel { // comment at the top of conv_grad_ops.h for details. xla::ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(batch_dim); - dnums.set_feature_dimension(feature_dim); + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); // TF filter shape is [ H, W, ..., inC, outC ] // Transpose the input and output features for computing the gradient. @@ -419,8 +423,10 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Each spatial entry has size in_depth * batch // Swap n_dim and c_dim in the activations. - dnums.set_batch_dimension(c_dim); - dnums.set_feature_dimension(n_dim); + dnums.set_input_batch_dimension(c_dim); + dnums.set_output_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index dde7898015e73190c96fa6effddfd3fc892264ea..7349dcb987cd88c423570889c0502d1a0bd12c52 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -199,6 +199,7 @@ class DynamicStitchOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); +REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2c7d445600bf81873b0ed44fc8dccb23dcc902d6..db449ec3451d90fe8dce2bef5bea3795dd908277 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -30,7 +30,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( XlaOpKernelContext* context, const xla::ComputationDataHandle& input, const TensorShape& input_shape, const xla::ComputationDataHandle& indices, const TensorShape& indices_shape, int64 axis, DataType dtype, - xla::ComputationBuilder* builder) { + DataType index_type, xla::ComputationBuilder* builder) { // Although the indices Tensor is flattened into rank 1 during the lookup, // and each scalar entry is used as an index into the first dimension of the // input, the output is returned with shape: @@ -80,22 +80,23 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Specify the shape of the loop-carried Tensor tuple. xla::PrimitiveType ptype; TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); + xla::PrimitiveType idxtype; + TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype)); std::vector tuple_shapes( {// The iteration counter i is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), + xla::ShapeUtil::MakeShape(idxtype, {}), // The input array has shape input_shape. Loop invariant. xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), // The gather indices are reshaped to rank 1. Loop invariant. - xla::ShapeUtil::MakeShape(xla::S32, {num_indices}), + xla::ShapeUtil::MakeShape(idxtype, {num_indices}), // The output array is rank >= 3, and is updated on each loop iteration. xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); // Construct the initial values of the loop-carried Tensors. - auto init_i = builder->ConstantR0(0); - auto init_out = - builder->Broadcast(builder->ConstantLiteral(xla::Literal::Zero(ptype)), - loop_out_shape.dim_sizes()); + auto init_i = XlaHelpers::Zero(builder, index_type); + auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + loop_out_shape.dim_sizes()); // Flatten the indices into 1-D for ease of iteration. auto indices_1d = builder->Reshape(indices, {num_indices}); auto init = builder->Tuple({init_i, input, indices_1d, init_out}); @@ -105,7 +106,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( "GatherWhileCond"); condb.Lt(condb.GetTupleElement( condb.Parameter(0, tuple_shape, "GatherWhileTuple"), 0), - condb.ConstantR0(num_indices)); + XlaHelpers::IntegerLiteral(&condb, index_type, num_indices)); auto cond_status = condb.Build(); auto cond = cond_status.ConsumeValueOrDie(); @@ -127,7 +128,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Slice from the input array. auto index = bodyb.DynamicSlice(indices, bodyb.Reshape(i, {1}), {1}); auto start_indices = bodyb.Pad( - bodyb.Reshape(index, {1}), bodyb.ConstantR0(0), + bodyb.Reshape(index, {1}), XlaHelpers::Zero(&bodyb, index_type), xla::MakeEdgePaddingConfig( {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); auto slice_i = bodyb.Reshape( @@ -136,7 +137,8 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Construct the index into the R3+ output Tensor 0, ..., , 0, ... std::vector out_index_vals( - loop_out_shape.dims(), bodyb.ConstantR1({0})); + loop_out_shape.dims(), + bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1})); out_index_vals[input_shape_pre_axis.dims() + extra_dims] = bodyb.Reshape(i, {1}); auto out_index = bodyb.ConcatInDim(out_index_vals, 0); @@ -144,8 +146,8 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Update the output Tensor auto updated_output = bodyb.DynamicUpdateSlice(output, slice_i, out_index); - bodyb.Tuple({bodyb.Add(i, bodyb.ConstantR0(1)), input, indices, - updated_output}); + bodyb.Tuple({bodyb.Add(i, XlaHelpers::One(&bodyb, index_type)), input, + indices, updated_output}); } auto body_status = bodyb.Build(); auto body = body_status.ConsumeValueOrDie(); @@ -156,124 +158,6 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( return builder->Reshape(gather_output, out_shape.dim_sizes()); } -namespace { - -class GatherOpCustomCall : public XlaOpKernel { - public: - explicit GatherOpCustomCall(OpKernelConstruction* context) - : XlaOpKernel(context) {} - - void Compile(XlaOpKernelContext* context) override { - const TensorShape params_shape = context->InputShape(0); - const auto params_dims = params_shape.dims(); - const TensorShape indices_shape = context->InputShape(1); - OP_REQUIRES( - context, TensorShapeUtils::IsVectorOrHigher(params_shape), - errors::InvalidArgument("params must be at least 1 dimensional")); - - DataType index_type = input_type(1); - OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - // GatherV2 added an axis argument. We support both Gather and GatherV2 in - // this kernel by defaulting axis to 0 if there are 2 inputs. - int64 axis = 0; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), - errors::InvalidArgument("axis must be scalar")); - DataType axis_type = input_type(2); - OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, - errors::InvalidArgument("axis must be int32 or int64")); - - xla::Literal literal; - OP_REQUIRES_OK(context, context->ConstantInput(2, &literal)); - int64 axis_input = axis_type == DT_INT32 ? literal.Get({}) - : literal.Get({}); - axis = axis_input < 0 ? axis_input + params_dims : axis_input; - OP_REQUIRES(context, 0 <= axis && axis < params_dims, - errors::InvalidArgument("Expected axis in the range [", - -params_dims, ", ", params_dims, - "), but got ", axis_input)); - } - - // Check that we have enough index space. - const int64 limit = index_type == DT_INT32 - ? std::numeric_limits::max() - : std::numeric_limits::max(); - OP_REQUIRES(context, params_shape.dim_size(axis) <= limit, - errors::InvalidArgument( - "params.shape[", axis, "] too large for ", - DataTypeString(index_type), - " indexing: ", params_shape.dim_size(axis), " > ", limit)); - - // The result shape is params.shape[0:axis] + indices.shape + - // params.shape[axis + 1:]. - TensorShape result_shape; - int64 outer_size = 1; - int64 inner_size = 1; - for (int i = 0; i < axis; i++) { - result_shape.AddDim(params_shape.dim_size(i)); - outer_size *= params_shape.dim_size(i); - } - result_shape.AppendShape(indices_shape); - for (int i = axis + 1; i < params_dims; i++) { - result_shape.AddDim(params_shape.dim_size(i)); - inner_size *= params_shape.dim_size(i); - } - - XlaContext& tc = XlaContext::Get(context); - OP_REQUIRES( - context, tc.allow_cpu_custom_calls(), - errors::InvalidArgument("Gather op requires CustomCall on CPU")); - - xla::ComputationBuilder& b = *context->builder(); - - // Call gather_xla_float_kernel (from gather_op_kernel_float.cc). - // XLA passes to the function, so it is not included here. - std::vector args; - args.push_back(tc.GetOrCreateRuntimeContextParameter()); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR0(indices_shape.num_elements()))); - args.push_back( - b.ConstantLiteral(*xla::Literal::CreateR0(outer_size))); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR0(params_shape.dim_size(axis)))); - args.push_back( - b.ConstantLiteral(*xla::Literal::CreateR0(inner_size))); - args.push_back(context->Input(0)); - args.push_back(context->Input(1)); - - xla::Shape xla_out_shape; - OP_REQUIRES_OK( - context, TensorShapeToXLAShape(DT_FLOAT, result_shape, &xla_out_shape)); - - // Call the custom code with args: - xla::ComputationDataHandle output; - if (index_type == DT_INT32) { - output = b.CustomCall("gather_float_int32_xla_impl", args, xla_out_shape); - } else { - output = b.CustomCall("gather_float_int64_xla_impl", args, xla_out_shape); - } - - context->SetOutput(0, output); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(GatherOpCustomCall); -}; - -REGISTER_XLA_OP(Name("Gather") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_CPU_XLA_JIT), - GatherOpCustomCall); -REGISTER_XLA_OP(Name("GatherV2") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_CPU_XLA_JIT), - GatherOpCustomCall); - -} // namespace - GatherOpDynamicSlice::GatherOpDynamicSlice(OpKernelConstruction* context) : XlaOpKernel(context) {} @@ -303,20 +187,17 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { ", ", params_dims, "), but got ", axis)); } - xla::ComputationDataHandle gather = - XlaComputeGatherDynamicSlice(context, input, input_shape, indices, - indices_shape, axis, DT_FLOAT, builder); + DataType index_type = input_type(1); + OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("indices must be int32 or int64")); + + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( + context, input, input_shape, indices, indices_shape, axis, DT_FLOAT, + index_type, builder); context->SetOutput(0, gather); } -REGISTER_XLA_OP(Name("Gather") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_GPU_XLA_JIT), - GatherOpDynamicSlice); - -REGISTER_XLA_OP(Name("GatherV2") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_GPU_XLA_JIT), - GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 5623c4d1c2be18b73696052c7fad822355a30f77..2c80395c56d73adad7dc1679ba6423fbe103605a 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -28,11 +28,13 @@ namespace tensorflow { // Adds to builder an XLA computation that performs a gather on input (of // shape input_shape) keyed on indices (of shape indices_shape). +// +// index_type must be must be DT_INT32 or DT_INT64. xla::ComputationDataHandle XlaComputeGatherDynamicSlice( XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, const TensorShape& input_shape, const xla::ComputationDataHandle& indices, const TensorShape& indices_shape, int64 axis, DataType dtype, - xla::ComputationBuilder* builder); + DataType index_type, xla::ComputationBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc deleted file mode 100644 index 33b1b087d00d8263cd80f7d5d879401e4ed6c0fb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/gather_functor.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); - - int64 indices_size = *static_cast(data[1]); - int64 params_x = *static_cast(data[2]); - int64 params_y = *static_cast(data[3]); - int64 params_z = *static_cast(data[4]); - - float* in = static_cast(data[5]); - - int32* indices = static_cast(data[6]); - Eigen::DSizes in_eig_sizes; - in_eig_sizes[0] = params_x; - in_eig_sizes[1] = params_y; - in_eig_sizes[2] = params_z; - tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); - - Eigen::DSizes indices_eig_sizes; - indices_eig_sizes[0] = indices_size; - tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = params_x; - out_eig_sizes[1] = indices_size; - out_eig_sizes[2] = params_z; - tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); - - tensorflow::functor::GatherFunctorCPU f; - const int64 bad_i = f(in_eig, indices_eig, out_eig); - if (bad_i != -1) { - tensorflow::XlaLocalRuntimeContext* runtime_context = - static_cast(data[0]); - runtime_context->error = true; - runtime_context->error_msg = "Invalid index for gather"; - for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; - } -} - -} // namespace tensorflow - -// Implements gather on CPU. This is called by an XLA custom call, set up by -// gather_op.cc. -extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { - tensorflow::gather_float_int32_xla_impl(out, data); -} diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc deleted file mode 100644 index 5e2d872ce0b28ab479c73ed1fea5f32804c21e22..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/gather_functor.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); - - int64 indices_size = *static_cast(data[1]); - int64 params_x = *static_cast(data[2]); - int64 params_y = *static_cast(data[3]); - int64 params_z = *static_cast(data[4]); - - float* in = static_cast(data[5]); - - int64* indices = static_cast(data[6]); - Eigen::DSizes in_eig_sizes; - in_eig_sizes[0] = params_x; - in_eig_sizes[1] = params_y; - in_eig_sizes[2] = params_z; - tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); - - Eigen::DSizes indices_eig_sizes; - indices_eig_sizes[0] = indices_size; - tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = params_x; - out_eig_sizes[1] = indices_size; - out_eig_sizes[2] = params_z; - tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); - - tensorflow::functor::GatherFunctorCPU f; - const int64 bad_i = f(in_eig, indices_eig, out_eig); - if (bad_i != -1) { - tensorflow::XlaLocalRuntimeContext* runtime_context = - static_cast(data[0]); - runtime_context->error = true; - runtime_context->error_msg = "Invalid index for gather"; - for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; - } -} - -} // namespace tensorflow - -// Implements gather on CPU. This is called by an XLA custom call, set up by -// gather_op.cc. -extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { - tensorflow::gather_float_int64_xla_impl(out, data); -} diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 87d3d64a4e9c07b8effce7583c4189b8c737d433..d2b1f7913ecc9113284827b53de8fb0e5b711322 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -24,7 +24,9 @@ class IdentityOp : public XlaOpKernel { explicit IdentityOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, ctx->Input(0)); + for (int i = 0; i < ctx->num_inputs(); ++i) { + ctx->SetOutput(i, ctx->Input(i)); + } } private: @@ -35,6 +37,7 @@ class IdentityOp : public XlaOpKernel { // dummy operator using CompilationOnly(). REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); +REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 66b99665cbefd9ffd2acabe6eb296f485ca6a59d..2421825ead17a3acee9f145f00904d382fb656f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -140,7 +140,7 @@ class TruncatedNormalOp : public XlaOpKernel { xla::ComputationBuilder* b) { xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); - return b->LogicalOr(too_large, too_small); + return b->Or(too_large, too_small); }; // The algorithm we're using is roughly: diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index dae2eb9d2a92ef8d4eabb8d6f9a79758c42d446d..647b6274083cf8886af6c451b746416445a4a2b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -129,7 +129,7 @@ class AllOp : public XlaReductionOp { void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { - builder->LogicalAnd(scalar_lhs, scalar_rhs); + builder->And(scalar_lhs, scalar_rhs); } }; @@ -147,7 +147,7 @@ class AnyOp : public XlaReductionOp { void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { - builder->LogicalOr(scalar_lhs, scalar_rhs); + builder->Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index a137d28118e6b4c66c70253817be9b3f0b75088a..12a35529992e6160566046dd28f9321c88afec91 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -77,9 +77,9 @@ class Relu6GradOp : public XlaOpKernel { b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); const auto six = b->Broadcast( XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); - auto out = b->Select( - b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), - ctx->Input(0), zero); + auto out = + b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); ctx->SetOutput(0, out); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index ed818c56ed0e6fa41374234d6f6712a2bbda94e2..5172781c0d05b6682fe92086654e3b86961949ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index a0d8ab4d73f7491fe96299c6cdc918f00a3d7a97..750a4c2dec8154f97f307978b3d8884271292279 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -202,7 +202,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. xla::ComputationDataHandle nan_or_zero = builder->Select( - builder->LogicalAnd( + builder->And( builder->Le(XlaHelpers::Zero(builder, indices_type), indices), builder->Lt(indices, XlaHelpers::IntegerLiteral( builder, indices_type, depth))), diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 44ee81461e5b31f15594c0dfb86f7219f9875768..795eb1794f577e0f7fd2a2068878e540ff0c1a1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -33,13 +33,16 @@ class SplitOp : public XlaOpKernel { explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + const int32 num_split = num_outputs(); const TensorShape index_shape = ctx->InputShape(0); + const TensorShape input_shape = ctx->InputShape(1); + xla::Literal literal_index; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); - int32 split_dim; + int32 split_dim_orig; if (index_shape.dims() == 0) { - split_dim = literal_index.Get({}); + split_dim_orig = literal_index.Get({}); } else { OP_REQUIRES( ctx, index_shape.dims() == 1, @@ -49,27 +52,28 @@ class SplitOp : public XlaOpKernel { ctx, index_shape.dim_size(0) == 1, errors::InvalidArgument("split_index input to Split Op must be a " "scalar or a vector with 1 element")); - split_dim = literal_index.Get({0}); + split_dim_orig = literal_index.Get({0}); } - const int32 num_split = num_outputs(); - const TensorShape input_shape = ctx->InputShape(1); - - OP_REQUIRES( - ctx, 0 <= split_dim && split_dim < input_shape.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input_shape.dims(), "), but got ", split_dim)); + int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() + : split_dim_orig; + OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("-input rank(-", input_shape.dims(), + ") <= split_dim < input rank (", + input_shape.dims(), "), but got ", + split_dim_orig)); OP_REQUIRES( ctx, num_split > 0, errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); - OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0, - errors::InvalidArgument( - "Number of ways to split should evenly divide the split " - "dimension, but got split_dim ", - split_dim, " (size = ", input_shape.dim_size(split_dim), - ") ", "and num_split ", num_split)); + OP_REQUIRES( + ctx, input_shape.dim_size(split_dim) % num_split == 0, + errors::InvalidArgument( + "Number of ways to split should evenly divide the split " + "dimension, but got split_dim ", + split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ", + "and num_split ", num_split)); // All the slices are the same size: this is the size along the // split dimension. diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index e2d3d40813b63360d37392b59b3c68cc478b077b..351fda251798e43b607fb445f2c98abd57b3d86b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -307,11 +307,12 @@ class TensorArrayGatherOp : public XlaOpKernel { OP_REQUIRES(ctx, indices_shape.dims() == 1, errors::InvalidArgument("indices must be rank 1")); auto indices = ctx->Input(1); + DataType index_type = ctx->input_type(1); xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, b); + ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); ctx->SetOutput(0, gather); } diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 6b8f5ec7b33cd448a7b06c5dfe4aac288e53e9c9..651bbe2b405df66cb6aff1ba7fe3957eba94d610 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -87,7 +87,8 @@ XLAJIT_MAKE_UNARY(Log, b->Log(x)); // TODO(b/34703906): use a more accurate implementation of log1p. XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); -XLAJIT_MAKE_UNARY(LogicalNot, b->LogicalNot(x)); +XLAJIT_MAKE_UNARY(Invert, b->Not(x)); +XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two @@ -104,9 +105,9 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, auto nearest_even_int = b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); auto is_odd = b->Eq(nearest_even_int, one); - return b->Select(b->LogicalOr(b->Gt(fraction, half), - b->LogicalAnd(b->Eq(fraction, half), is_odd)), - b->Add(round_val, one), round_val); + return b->Select( + b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), + b->Add(round_val, one), round_val); } XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); @@ -129,8 +130,28 @@ XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Softplus, - b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0))))); + +static xla::ComputationDataHandle Softplus( + xla::ComputationBuilder* b, DataType dtype, + const xla::ComputationDataHandle& features) { + xla::ComputationDataHandle threshold = + b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), + XlaHelpers::FloatLiteral(b, dtype, 2.0)); + // Value above which exp(x) may overflow, but softplus(x) == x + // is within machine epsilon. + xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold)); + // Value below which exp(x) may underflow, but softplus(x) == exp(x) + // is within machine epsilon. + xla::ComputationDataHandle too_small = b->Lt(features, threshold); + xla::ComputationDataHandle features_exp = b->Exp(features); + xla::ComputationDataHandle output = b->Select( + too_large, features, + b->Select(too_small, features_exp, + b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); + return output; +} +XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); + // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, b->Div(x, diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 4ae983854780e49d8c7af29f218d3fbb7db225e2..b19ea22f50d2dd44e8d1d81f5930263f364030e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -111,9 +111,10 @@ class ResourceGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); + DataType index_type = ctx->input_type(1); xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( ctx, resource_handle, resource_shape, indices, indices_shape, 0, - resource_dtype, builder); + resource_dtype, index_type, builder); ctx->SetOutput(0, gather); } }; diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc index c1005405f9a9b09e4a6480332861d0cce2c52291..4a669f8e6eaf644f119f3c0a66f29d9f2c9a9d16 100644 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc @@ -34,14 +34,41 @@ output = input; While (Cond(output)) { output = Body(output) } input: A list of input tensors whose types are T. output: A list of output tensors whose types are T. cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc index b6947bfe570c75dd0c7c6301b972e2012bae26bd..4b41c16a8b3fdc0c3412c76d29d3ec2b7bdfd0aa 100644 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc @@ -37,7 +37,14 @@ REGISTER_OP("_XLARecv") .Attr("tensor_name: string") .Attr("shape: shape") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) .Doc(R"doc( Receives the named tensor from another XLA computation. diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index b7213a6cc1e4066f98523ec57681f4c0651f71b5..a14c93a2b9494b89f579bc20ee0510c136f8f01b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -255,11 +255,10 @@ Status CreateXlaArgs(const Graph& graph, Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, xla::Computation* computation, bool* requires_runtime_context) { - // Create a device and context to convert the graph into an XLA computation. XlaOpRegistry::RegisterCompilationKernels(); - // Populate the context with args from the graph. for (Node* node : graph->nodes()) { - node->set_assigned_device_name(DEVICE_CPU_XLA_JIT); + node->set_assigned_device_name( + strings::StrCat("/device:", DEVICE_CPU_XLA_JIT)); } std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index b54848f342406c9211c06664fd1f6c0783e0891f..c6984887766e7778d2f8f2fdbd0d626cf9451d86 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -43,6 +43,12 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_UINT16: *type = xla::U16; return Status::OK(); + case tensorflow::DT_UINT32: + *type = xla::U32; + return Status::OK(); + case tensorflow::DT_UINT64: + *type = xla::U64; + return Status::OK(); case tensorflow::DT_HALF: *type = xla::F16; return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..b5c17c5273bb15e20184b2fefd93880d4828105e --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" + +#include +#include "tensorflow/compiler/aot/runtime.h" + +namespace tensorflow { + +XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, + AllocMode alloc_mode) + : raw_function_(static_data.raw_function), + result_index_(static_data.result_index), + args_(new void*[static_data.num_args]), + temps_(new void*[static_data.num_temps]), + arg_names_(static_data.arg_names), + result_names_(static_data.result_names), + program_shape_(static_data.program_shape) { + // Allocate arg and temp buffers. + if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + static_data.arg_sizes, static_data.num_args, args_, + /*annotate_initialized=*/false); + } + alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + static_data.temp_sizes, static_data.num_temps, temps_, + /*annotate_initialized=*/true); + + // The runtime context is always the last arg, if it is required. + if (static_data.requires_runtime_context) { + args_[static_data.num_args - 1] = &context_; + } +} + +XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { + tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); + tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); + delete[] args_; + delete[] temps_; +} + +namespace { + +// Linear search through `names` looking for a match with `name`. Returns -1 if +// the name isn't found, or is empty. +// +// REQUIRES: `names` is a nullptr-terminated array. +int LookupNameIndex(const string& name, const char** names) { + // Hitting this assert means that there is no name-to-index data available; + // for AOT try the setting the tfcompile --gen_name_to_index flag. + assert(names != nullptr); + + constexpr int kNotFound = -1; + if (name.empty()) { + return kNotFound; + } + for (int index = 0; names[index] != nullptr; ++index) { + if (name == names[index]) { + return index; + } + } + return kNotFound; +} + +} // namespace + +int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const { + return LookupNameIndex(name, arg_names_); +} + +int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const { + return LookupNameIndex(name, result_names_); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h new file mode 100644 index 0000000000000000000000000000000000000000..01e6b4c071a057429b78171b1c6ff2f38bb85590 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/types.h" + +// Forward-declare, rather than include, to reduce code size for users that +// never use this functionality. +namespace xla { +class ProgramShape; +} + +namespace tensorflow { + +// Represents a function compiled by XLA, produced via either JIT or AOT. +// +// The Run method invokes the actual computation, with inputs read from arg +// buffers, and outputs written to result buffers. Each Run call may also use a +// set of temporary buffers for the computation. +// +// By default each instance of this class manages its own arg, result and temp +// buffers. The AllocMode constructor parameter may be used to modify the buffer +// allocation strategy. +// +// Under the default allocation strategy, this class is thread-compatible: +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. +class XlaCompiledCpuFunction { + public: + // Type of the raw function, produced by either JIT or AOT. + // + // TODO(toddw): Add support for hlo profiling, and replace std::function with + // a raw function pointer, for some codesize savings. + using RawFunction = std::function; + + // StaticData represents the state necessary to run an XLA-compiled + // function. For JIT this is backed by data in XlaCompiledCpuFunctionJit; for + // AOT this is backed by data compiled into the object file. + struct StaticData { + // The raw function to call. + RawFunction raw_function; + + // Cardinality and sizes of arg and temp buffers. + const intptr_t* arg_sizes = nullptr; + size_t num_args = 0; + const intptr_t* temp_sizes = nullptr; + size_t num_temps = 0; + + // The 0-based index of the result tuple, in the temp buffers. + size_t result_index = 0; + + // Is the final arg XlaLocalRuntimeContext? + bool requires_runtime_context = false; + + // [Optional] Arrays of arg and result names. These are arrays of C-style + // strings, where the array is terminated by nullptr. + const char** arg_names = nullptr; + const char** result_names = nullptr; + + // [Optional] Arg and result shapes. + const xla::ProgramShape* program_shape = nullptr; + }; + + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + // Allocate all buffers - args, results and temps. + ARGS_RESULTS_AND_TEMPS, + + // Only allocate result and temp buffers. + // Use set_arg_data to set argument buffers before Run is called. + RESULTS_AND_TEMPS_ONLY, + }; + + XlaCompiledCpuFunction( + const StaticData& static_data, + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + virtual ~XlaCompiledCpuFunction(); + + XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; + XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; + + // Sets the intra-op thread pool used to run individual ops concurrently. + void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { + run_options_.set_intra_op_thread_pool(pool); + context_.thread_pool = pool; + } + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run() { + context_.error = false; + context_.error_msg.clear(); + raw_function_(temps_[result_index_], &run_options_, + const_cast(args_), temps_); + return !context_.error; + } + + // Returns the error message from the previous failed Run call. + const string& error_msg() const { return context_.error_msg; } + + // ------------------------------ + // Arg methods for managing input buffers. Buffers are in row-major order. + + // Returns the underlying array of argument buffers, where args()[I] is the + // buffer for the positional argument at index I. + void** args() { return args_; } + const void* const* args() const { return args_; } + + // Returns the buffer for the positional argument at the given `index`. + void* arg_data(size_t index) { return args_[index]; } + const void* arg_data(size_t index) const { return args_[index]; } + + // Sets the buffer for the positional argument at the given `index` to `data`. + // Must be called before Run to have an effect. May be called under any + // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be + // called for each positional argument, in order to set the argument buffers. + // + // Allocated memory must be aligned to the size specified by + // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in + // tensorflow/compiler/aot/runtime.h to ensure correct alignment. + // + // If StaticData.requires_runtime_context==true, the final argument is an + // XlaLocalRuntimeContext, which is managed internally by this class, and + // should not be changed. + // + // Aliasing of argument and result buffers is not allowed, and results in + // undefined behavior. + void set_arg_data(size_t index, void* data) { args_[index] = data; } + + // ------------------------------ + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. Unlike the arg methods, + // there is no set_resultN_data method. The result buffers are managed + // internally, and may change after each call to Run. + + // Returns the underlying array of result buffers, where results()[I] is the + // buffer for the positional result at index I. + void** results() { return static_cast(temps_[result_index_]); } + const void* const* results() const { + return static_cast(temps_[result_index_]); + } + + // Returns the buffer for the positional result at the given `index`. + void* result_data(size_t index) { return results()[index]; } + const void* result_data(size_t index) const { return results()[index]; } + + // ------------------------------ + // Methods for extracting optional metadata. + + // Returns true iff data is available for the Lookup{Arg,Result}Index methods. + // E.g. the data might not be compiled into the binary for AOT. + bool HasNameIndices() const { + return arg_names_ != nullptr && result_names_ != nullptr; + } + + // Returns the 0-based index for the argument with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupArgIndex(const string& name) const; + + // Returns the 0-based index for the result with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupResultIndex(const string& name) const; + + // Returns the shape of the args and results. May return nullptr if the + // program shape isn't available. + const xla::ProgramShape* ProgramShape() const { return program_shape_; } + + private: + const RawFunction raw_function_; + const size_t result_index_; + + // Arrays of argument and temp buffers; entries in args_ may be overwritten by + // the user. + void** args_ = nullptr; + void** temps_ = nullptr; + + // Backing memory for individual arg and temp buffers. + void* alloc_args_ = nullptr; + void* alloc_temps_ = nullptr; + + // Options and context passed to the compiled function. + xla::ExecutableRunOptions run_options_; + tensorflow::XlaLocalRuntimeContext context_; + + // Optional metadata. + const char** arg_names_ = nullptr; + const char** result_names_ = nullptr; + const xla::ProgramShape* program_shape_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8521d4167a11ebb7d8af87ca2e18e0140bc76eb9..a82ef02e32c78373ec2aa56558f525d7b825d861 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,12 +18,15 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" +#include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -92,7 +95,6 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) } local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), - FunctionDefLibrary{})); local_pflr_.reset(new ProcessFunctionLibraryRuntime( &device_mgr_, Env::Default(), options.graph_def_version, @@ -127,6 +129,37 @@ static Status GetFunctionBody(const NameAttrList& function, return Status::OK(); } +Status XlaCompiler::FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody) { + // The function may be in either the local_flib_runtime_ or flib_runtime_. + // Look up the function in local first and if it is not found then look up the + // function in flib_runtime_. + auto status = GetFunctionBody(function, local_flib_runtime_, fbody); + if (!status.ok()) { + if (!errors::IsNotFound(status)) { + return status; + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + GetFunctionBody(function, flib_runtime_, fbody), + "Local lookup failed with: ", status.error_message()); + } + return Status::OK(); +} + +std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { + std::unique_ptr graph(new Graph(options_.flib_def)); + CopyGraph(*fbody->graph, graph.get()); + OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); + GraphOptimizer optimizer(opts); + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), + /*device=*/nullptr, &graph, /*shape_map=*/nullptr); + + return graph; +} + Status XlaCompiler::CompileFunction( const XlaCompiler::CompileOptions& options, const NameAttrList& function, const std::vector& args, @@ -142,9 +175,7 @@ Status XlaCompiler::CompileFunction( } const FunctionBody* fbody; - if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) { - TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody)); - } + TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody)); TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); @@ -181,7 +212,7 @@ namespace { Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, XlaCompilationDevice* device, FunctionLibraryRuntime* flib, int64 step_id) { - // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the + // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the // resource manager takes ownership via Create, and unrefs via Cleanup. We // explicitly add a reference to ensure the refcount at entry is maintained at // all exit points; Create and Cleanup are always called in this function. @@ -198,55 +229,12 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - // Create a LocalExecutor that will own and run the graph. - // TODO(b/66947550): migrate away from using an Executor in order to guarantee - // determinism and thread-safety. - LocalExecutorParams exec_params; - exec_params.device = device; - exec_params.function_library = flib; - exec_params.create_kernel = [flib](const NodeDef& ndef, OpKernel** kernel) { - return flib->CreateKernel(ndef, kernel); - }; - exec_params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; - Executor* exec_ptr = nullptr; - TF_RETURN_IF_ERROR(NewLocalExecutor(exec_params, graph.release(), &exec_ptr)); - std::unique_ptr exec(exec_ptr); - // At this point ownership of the graph has been transferred to exec. - - // Run the graph symbolically, turning the graph into an XLA computation. - Executor::Args exec_args; - exec_args.step_id = step_id; - exec_args.step_container = step_container.get(); - - // Pushes closures to run onto `worklist`. We don't run the closures directly - // from 'runner' since that might lead to a stack overflow for large graphs. - std::deque worklist; - exec_args.runner = [&](Executor::Args::Closure c) { - worklist.push_back(std::move(c)); - }; - - // The following code assumes there is only one thread involved and no - // concurrency, because we did not provide Executor a threaded runner. Async - // ops on the XlaCompilation device must not use threads or concurrency - // internally. - bool done = false; - exec->RunAsync(exec_args, [&](const Status& s) { - status = s; - done = true; - }); - // Repeatedly run closures from the worklist until `done` is signalled. - while (!done) { - TF_RET_CHECK(!worklist.empty()); - Executor::Args::Closure& c = worklist.front(); - c(); - worklist.pop_front(); - } - TF_RETURN_WITH_CONTEXT_IF_ERROR( - status, "Conversion from TensorFlow graph to XLA computation failed."); - + GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, + step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); - return status; + return Status::OK(); } // Builds XLA computations for each of the arguments to the computation. @@ -509,7 +497,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->requires_runtime_context = context->has_context_parameter(); // Tuple arguments and runtime context parameters are incompatible. - CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); + TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; @@ -546,7 +534,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, i < context->retvals().size(); ++i) { const XlaExpression& retval = context->retvals()[i]; if (!retval.has_constant_value()) { - CHECK_LT(computation_output, num_computation_outputs); + TF_RET_CHECK(computation_output < num_computation_outputs) + << "Computation has more outputs than expected"; OutputDescription& output = result->outputs[i]; output.is_constant = false; TF_RETURN_IF_ERROR(XLAShapeToTensorShape( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 35159dbad4117895908584ad48878e2a989b9f40..a8882a638caf2d742bfa2b4f68140e1dc4520db1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { - // The XlaCompiler class is responsible for compilation of a self-contained // subgraph of a TensorFlow computation using the XLA linear algebra runtime. // It does a symbolic execution of the graph starting from specific input @@ -136,6 +135,27 @@ class XlaCompiler { bool operator==(const Argument& other) const; }; + // Options pertaining to an individual call to CompileGraph() or + // CompileFunction(). + struct CompileOptions { + // If `use_tuple_arg` is true, a single tuple parameter will be used for all + // arguments; if false, each argument gets its own parameter. + bool use_tuple_arg = false; + + // If 'return_updated_values_for_all_resources' is true, then updated + // values of all resource arguments will be included in the + // 'resource_updates' of the computation, even if the resource was not + // modified by the computation. Used when compiling loop bodies to ensure + // the input and output signatures match. + bool return_updated_values_for_all_resources = false; + + // If 'resolve_compile_time_constants' is true, then outputs of a + // computation that are known to be compile-time constants will be returned + // as Tensors at compile-time, rather than as run-time outputs of the + // computation. + bool resolve_compile_time_constants = true; + }; + struct OutputDescription { // Type and shape of the output. DataType type; @@ -230,39 +250,9 @@ class XlaCompiler { }; explicit XlaCompiler(Options options); - ~XlaCompiler(); - - // Options pertaining to an individual call to CompileGraph() or - // CompileFunction(). - struct CompileOptions { - // If `use_tuple_arg` is true, a single tuple parameter will be used for all - // arguments; if false, each argument gets its own parameter. - bool use_tuple_arg = false; - - // If 'return_updated_values_for_all_resources' is true, then updated - // values of all resource resources arguments will be included in the - // 'resource_updates' of the computation, even if the resource was not - // modified by the computation. Used when compiling loop bodies to ensure - // the input and output signatures match. - bool return_updated_values_for_all_resources = false; - // If 'resolve_compile_time_constants' is true, then outputs of a - // computation that are known to be compile-time constants will be returned - // as Tensors at compile-time, rather than as run-time outputs of the - // computation. - bool resolve_compile_time_constants = true; - }; + ~XlaCompiler(); - // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. - // `args` describes the arguments to the function, each of which must either - // be a runtime-parameter to the XLA computation, a compile-time constant, or - // a resource variable. Writes the compiled output to `result`. - // - // The generated XLA computation returns a tuple containing only the - // non-constant outputs as a function of the input arguments. Constant - // arguments are returned as host memory tensors in the output list and are - // not included in the XLA computation's outputs. The XLA computation is - // null if there are no data-dependent outputs and no side effects. Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, const std::vector& args, @@ -276,10 +266,17 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); + Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func, + const std::vector& types, + const std::vector& shapes, + const std::vector& expressions, + std::vector* args); + // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. - // Channel handles can be used to communicate between different computations. - // Computations that communicate should be compiled with the same XlaCompiler. + // Channel handles can be used to communicate between different + // computations. Computations that communicate should be compiled with the + // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); const Options& options() const { return options_; } @@ -287,6 +284,18 @@ class XlaCompiler { FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } private: + // Sets the function body `fbody` to the one registered as `function`. + Status FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody); + + // Returns the optimized graph object in this function body. + std::unique_ptr GetGraph(const FunctionBody* fbody); + + // Graph compiler needs to know how to get an optimized graph from a function + // body. + friend class GraphCompiler; + friend class XlaCompilerTest; + Options options_; // Status set to non-OK in the constructor if initialization fails. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 531725a62335fc30086de2fe381591eb7d0976d0..93aae8485d157cd4afbf804d695d5c0ab8d7946c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" @@ -36,6 +38,37 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class XlaCompilerTest : public ::testing::Test { + protected: + XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} + + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + + XlaOpRegistry::RegisterCompilationKernels(); + + FunctionDefLibrary flib; + flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + } + + XlaCompiler::Options DefaultOptions() { + XlaCompiler::Options options; + options.device_type = &cpu_device_type_; + options.client = client_; + options.flib_def = flib_def_.get(); + return options; + } + + FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) { + return compiler->local_flib_def_.get(); + } + + DeviceType cpu_device_type_; + xla::Client* client_; + std::unique_ptr flib_def_; +}; + namespace { // Helper class to test the ability to pass resources through to XLA @@ -63,6 +96,7 @@ class DummyReadResourceOp : public XlaOpKernel { dummy->Unref(); ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(1, ctx->Input(0)); } }; @@ -80,22 +114,25 @@ class DummyReadResourceCC { if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); if (!scope.ok()) return; - this->output_ = Output(ret, 0); + this->output1_ = Output(ret, 0); + this->output2_ = Output(ret, 1); } - Node* node() const { return output_.node(); } - Output output_; + Output output1_; + Output output2_; }; REGISTER_OP("DummyReadResource") .Input("input: int32") - .Output("output: int32") + .Output("output1: int32") + .Output("output2: int32") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( A dummy Op. input: dummy input. -output: dummy output. +output1: dummy output. +output2: dummy output. )doc"); REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp); @@ -125,31 +162,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), DummyDuplicateOp); -class XlaCompilerTest : public ::testing::Test { - protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - - void SetUp() override { - client_ = xla::ClientLibrary::LocalClientOrDie(); - - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - } - - XlaCompiler::Options DefaultOptions() { - XlaCompiler::Options options; - options.device_type = &cpu_device_type_; - options.client = client_; - options.flib_def = flib_def_.get(); - return options; - } - - DeviceType cpu_device_type_; - xla::Client* client_; - std::unique_ptr flib_def_; -}; // Tests compilation and execution of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { @@ -316,7 +328,8 @@ TEST_F(XlaCompilerTest, ResourceManager) { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto b = DummyReadResourceCC(scope.WithOpName("B"), a); - auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0); + auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(scope.ToGraph(graph.get())); @@ -349,6 +362,58 @@ TEST_F(XlaCompilerTest, ResourceManager) { resource->Unref(); } +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, DeterministicCompilation) { + // Builds a graph that contains a node with two output edges. The compiler + // should always traverse them in the same order. + const int64 test_count = 2; + + std::vector results(test_count); + + for (int64 i = 0; i < test_count; ++i) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::Neg(scope.WithOpName("B"), a); + auto c = ops::Neg(scope.WithOpName("C"), a); + auto d = ops::Add(scope.WithOpName("D"), b, c); + auto e = ops::_Retval(scope.WithOpName("E"), d, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the argument. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + + // Compiles the graph. + auto options = DefaultOptions(); + XlaCompiler compiler(options); + + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", + std::move(graph), args, &results[i])); + } + + for (int64 i = 1; i < test_count; ++i) { + auto m1 = + results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests(); + auto m2 = + results[i].computation->Snapshot().ValueOrDie()->entry().requests(); + // Check if every entry is the same. + for (auto& entry1 : m1) { + int64 key = entry1.first; + auto value1 = entry1.second; + auto entry2 = m2.find(key); + auto value2 = entry2->second; + EXPECT_TRUE(entry2 != m2.end()); + string str1, str2; + value1.AppendToString(&str1); + value2.AppendToString(&str2); + EXPECT_EQ(str1, str2); + } + } +} + // Tests a computation that receives a TensorArray resource as input and // updates it. TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { @@ -489,5 +554,104 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { EXPECT_EQ(1, result.resource_updates.size()); } +// Tests CompileFunction with undefined function fails. +TEST_F(XlaCompilerTest, UndefinedFunctionFails) { + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("Function_NotDefined_"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); +} + +FunctionDef FillFn() { + return FunctionDefHelper::Define( + // Name + "FillFn", + // Args + {"x: T", "dims: int32"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}}); +} + +TEST_F(XlaCompilerTest, FunctionCallWithConstants) { + // Certain operations in a function, "Fill" for example, requires the + // operator's argument to be a compile-time constant instead of a parameter. + // This testcase tests if XlaCompiler can handle such operators inside + // function calls. + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + *flib.add_function() = FillFn(); + + TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn())); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the argument. + std::vector args; + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result)); +} + +// Tests CompileFunction with a local function lookup failing, fails with +// informative error about both lookups. +TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { + XlaCompiler compiler(DefaultOptions()); + + auto local_flib_def = LocalFlibDef(&compiler); + TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo())); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("XTimesTwo"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + + ASSERT_FALSE(status.ok()); + // Flib lookup failure. + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); + // Local flib lookup failure. + EXPECT_TRUE( + StringPiece(status.error_message()).contains("Attr T is not found")) + << status.error_message(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 2df9a0ed003748512a6d173f83864aa5e0d65e35..f59b83cfdd778209935970981a1463d350a64be6 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -54,6 +54,19 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, return b->ConstantLiteral(xla::Literal::One(type)); } +xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, + DataType data_type) { + switch (data_type) { + case DT_FLOAT: + return b->ConstantR0(std::numeric_limits::epsilon()); + case DT_DOUBLE: + return b->ConstantR0(std::numeric_limits::epsilon()); + default: + LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " + << DataTypeString(data_type); + } +} + xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationBuilder* b, DataType data_type, int64 value) { xla::Literal literal; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index e312f2c400c4fc3d865f49c739b0d3797c213f1e..af23d20fd306c03b5e47c5ca9dd042187a2d51ed 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -48,6 +48,11 @@ class XlaHelpers { static xla::ComputationDataHandle One(xla::ComputationBuilder* b, DataType data_type); + // Returns the machine epsilon for floating-point type `data_type`, i.e., + // the difference between 1.0 and the next representable value. + static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b, + DataType data_type); + // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dd454ea8d57e21526e5bcde0c8efc5514983b93 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -0,0 +1,217 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/tf2xla.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +// Returns a vector of positional argument buffer sizes. +xla::StatusOr> ComputeArgSizes( + const xla::ProgramShape& program_shape, bool requires_runtime_context) { + std::vector arg_sizes; + const size_t num_args = program_shape.parameters_size(); + arg_sizes.reserve(num_args); + for (int i = 0; i < num_args; ++i) { + const xla::Shape& arg_shape = program_shape.parameters(i); + if (i == num_args - 1 && requires_runtime_context) { + // If the compiled function needs an XlaLocalRuntimeContext* arg, it's + // always last, and must be represented as an opaque type. + const xla::PrimitiveType type = arg_shape.element_type(); + if (type != xla::OPAQUE) { + return errors::InvalidArgument( + "expected final context arg to be opaque, but got type: ", + xla::PrimitiveType_Name(type), ", from program shape: ", + xla::ShapeUtil::HumanString(program_shape)); + } + arg_sizes.push_back(-1); + } else { + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); + } + } + return std::move(arg_sizes); +} + +// Returns a vector of positional temporary buffer sizes. +xla::StatusOr> ComputeTempSizes( + const xla::BufferAssignment& buffer_assignment) { + const std::vector& allocations = + buffer_assignment.Allocations(); + std::vector temp_sizes; + temp_sizes.reserve(allocations.size()); + for (const xla::BufferAllocation& allocation : allocations) { + // Callers don't allocate temporary buffers for parameters. Nor for + // thread-local buffers, which are lowered to alloca. + if (allocation.is_entry_computation_parameter() || + allocation.is_thread_local()) { + temp_sizes.push_back(-1); + } else { + temp_sizes.push_back(allocation.size()); + } + } + return std::move(temp_sizes); +} + +// Returns the index of the result in the temp buffers. +xla::StatusOr ComputeResultIndex( + const xla::BufferAssignment& buffer_assignment) { + TF_ASSIGN_OR_RETURN(const xla::BufferAllocation::Slice result_slice, + buffer_assignment.GetUniqueTopLevelOutputSlice()); + return result_slice.index(); +} + +// Adapt ComputeFunctionType, which includes a final profile_counters arg, to +// RawFunction, which doesn't include that final arg. +// +// TODO(toddw): Change RawFunction and AOT to also pass the final +// profile_counters arg, and remove this adapter. +XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( + xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { + return [compute_function](void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps) { + return compute_function(result, run_options, args, temps, + /*profile_counters=*/nullptr); + }; +} + +// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold +// the actual strings in nonempty_names, and hold arrays of pointers in +// name_ptrs, terminated by a nullptr entry. +template +void CollectNames(const T& entries, std::vector* nonempty_names, + std::vector* name_ptrs) { + // First collect `nonempty_names`, to ensure the underlying strings won't + // change out from under us. + for (const auto& entry : entries) { + const string& name = entry.name(); + if (!name.empty()) { + nonempty_names->push_back(name); + } + } + // Now set `name_ptrs` pointing to the strings in `nonempty_names`. + name_ptrs->reserve(entries.size() + 1); // +1 for nullptr array terminator + size_t nonempty_index = 0; + for (const auto& entry : entries) { + const string& name = entry.name(); + if (!name.empty()) { + name_ptrs->push_back(nonempty_names->at(nonempty_index).c_str()); + ++nonempty_index; + } else { + name_ptrs->push_back(""); + } + } + name_ptrs->push_back(nullptr); // array terminator +} + +} // namespace + +/*static*/ xla::StatusOr> +XlaJitCompiledCpuFunction::Compile( + const GraphDef& graph_def, const tf2xla::Config& config, + const xla::ExecutableBuildOptions& build_options) { + // Convert the graph_def into an xla::Computation. + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + xla::ClientLibrary::GetOrCreateLocalClient()); + xla::Computation computation; + bool requires_runtime_context; + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( + graph_def, config, client, &computation, &requires_runtime_context)); + + // Get and verify the program shape. + TF_ASSIGN_OR_RETURN(std::unique_ptr program_shape, + client->GetComputationShape(computation)); + if (program_shape->result().element_type() != xla::TUPLE) { + // The XlaCompiler we use to build the xla computation always generates a + // tuple result, and XlaCompiledCpuFunction relies on this for simpler + // calling semantics. + return errors::Internal( + "XlaJitCompiledCpuFunction requires the XLA result to be a tuple"); + } + // The parameter names are currently meaningless, and redundant with the rest + // of our metadata, so clear them out to avoid confusion and save space. + program_shape->clear_parameter_names(); + + // Compute arg shapes, needed to compile the executable. + std::vector arg_shapes; + arg_shapes.reserve(program_shape->parameters_size()); + for (int i = 0; i < program_shape->parameters_size(); ++i) { + arg_shapes.push_back(&program_shape->parameters(i)); + } + + // Compile the executable. The static_cast to the CpuExecutable subclass is + // necessary since the raw function and buffer assignments are only available + // there. + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + client->Compile(computation, arg_shapes, build_options)); + const xla::cpu::CpuExecutable* cpu_executable = + static_cast(executable->executable()); + XlaCompiledCpuFunction::RawFunction raw_function = + RawFunctionAdapter(cpu_executable->compute_function()); + const xla::BufferAssignment& buffer_assignment = + cpu_executable->buffer_assignment(); + + // Compute buffer sizes and the result index, needed to run the raw function. + TF_ASSIGN_OR_RETURN( + std::vector arg_sizes, + ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector temp_sizes, + ComputeTempSizes(buffer_assignment)); + TF_ASSIGN_OR_RETURN(size_t result_index, + ComputeResultIndex(buffer_assignment)); + + std::unique_ptr jit_unique_ptr( + new XlaJitCompiledCpuFunction); + XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); + jit->executable_ = std::move(executable); + jit->arg_sizes_ = std::move(arg_sizes); + jit->temp_sizes_ = std::move(temp_sizes); + jit->program_shape_ = std::move(program_shape); + jit->static_data_.raw_function = std::move(raw_function); + jit->static_data_.arg_sizes = jit->arg_sizes_.data(); + jit->static_data_.num_args = jit->arg_sizes_.size(); + jit->static_data_.temp_sizes = jit->temp_sizes_.data(); + jit->static_data_.num_temps = jit->temp_sizes_.size(); + jit->static_data_.result_index = result_index; + jit->static_data_.requires_runtime_context = requires_runtime_context; + // Optional metadata is collected and set below. + CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); + CollectNames(config.fetch(), &jit->nonempty_result_names_, + &jit->result_names_); + jit->static_data_.arg_names = jit->arg_names_.data(); + jit->static_data_.result_names = jit->result_names_.data(); + jit->static_data_.program_shape = jit->program_shape_.get(); + return std::move(jit_unique_ptr); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h new file mode 100644 index 0000000000000000000000000000000000000000..af307ae4eff74927242c4650d8a43710e991cc52 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the result of JIT compilation by XLA down to a function. This +// class holds the state necessary to create XlaCompiledCpuFunction instances, +// which are used to actually invoke the compiled computation. +// +// XlaJitCompiledCpuFunction must outlive the XlaCompiledCpuFunctions that are +// created from it. It holds state shared by all of the functions, including the +// JIT-compiled function itself, along with buffer sizes and other metadata +// necessary for execution. +class XlaJitCompiledCpuFunction { + public: + // Compile a tensorflow::GraphDef into an XlaJitCompiledCpuFunction. The given + // `config` specifies the portion of the graph to compile, via feeds and + // fetches. Each feed is a positional input argument for the compiled + // function, while each fetch is a positional output argument. + static xla::StatusOr> Compile( + const GraphDef& graph_def, const tf2xla::Config& config, + const xla::ExecutableBuildOptions& build_options); + + XlaJitCompiledCpuFunction(const XlaJitCompiledCpuFunction&) = delete; + XlaJitCompiledCpuFunction& operator=(const XlaJitCompiledCpuFunction&) = + delete; + + // Returns static data used to create an XlaCompiledCpuFunction instance, + // which represents the JIT-compiled function. The static data is unchanging + // across each instance. + const XlaCompiledCpuFunction::StaticData& StaticData() const { + return static_data_; + } + + private: + XlaJitCompiledCpuFunction() {} + + // The executable holds the underlying function. + std::unique_ptr executable_; + + // The static data is backed by the rest of the state in this class. + XlaCompiledCpuFunction::StaticData static_data_; + + // The backing arrays of arg and temp buffer sizes. + std::vector arg_sizes_; + std::vector temp_sizes_; + + // The backing arrays of arg and result names. We hold the actual strings in + // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static + // data to refer to. + std::vector nonempty_arg_names_; + std::vector nonempty_result_names_; + std::vector arg_names_; + std::vector result_names_; + + // The backing data for the program shape. + std::unique_ptr program_shape_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5bee68eefc8d9452b63113c080fc86d39550e899 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h" + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +AttrValue TypeAttrValue(DataType type) { + AttrValue attr_value; + SetAttrValue(type, &attr_value); + return attr_value; +} + +GraphDef SumGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* sum = graph_def.add_node(); + sum->set_name("sum"); + sum->set_op("Add"); + sum->add_input("x"); + sum->add_input("y"); + (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32); + return graph_def; +} + +tf2xla::Config SumConfig() { + tf2xla::Config config; + tf2xla::Feed* x = config.add_feed(); + x->mutable_id()->set_node_name("x"); + x->set_name("x_name"); + tf2xla::Feed* y = config.add_feed(); + y->mutable_id()->set_node_name("y"); + y->set_name("y_name"); + tf2xla::Fetch* sum = config.add_fetch(); + sum->mutable_id()->set_node_name("sum"); + sum->set_name("sum_name"); + return config; +} + +TEST(XlaJitCompiledCpuFunction, Sum) { + GraphDef graph_def = SumGraph(); + tf2xla::Config config = SumConfig(); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr jit, + XlaJitCompiledCpuFunction::Compile(graph_def, config, + xla::ExecutableBuildOptions())); + XlaCompiledCpuFunction function(jit->StaticData()); + + // Run the function and check results. + *static_cast(function.arg_data(0)) = 10; + *static_cast(function.arg_data(1)) = 32; + EXPECT_TRUE(function.Run()); + EXPECT_EQ(function.error_msg(), ""); + EXPECT_EQ(*static_cast(function.result_data(0)), 42); + + // Run the function again. + *static_cast(function.arg_data(0)) = 100; + *static_cast(function.arg_data(1)) = 320; + EXPECT_TRUE(function.Run()); + EXPECT_EQ(function.error_msg(), ""); + EXPECT_EQ(*static_cast(function.result_data(0)), 420); + + // Check name to index lookups. + EXPECT_TRUE(function.HasNameIndices()); + + EXPECT_EQ(function.LookupArgIndex("x_name"), 0); + EXPECT_EQ(function.LookupArgIndex("y_name"), 1); + EXPECT_EQ(function.LookupArgIndex(""), -1); + EXPECT_EQ(function.LookupArgIndex("x"), -1); + EXPECT_EQ(function.LookupArgIndex("y"), -1); + EXPECT_EQ(function.LookupArgIndex("sum"), -1); + EXPECT_EQ(function.LookupArgIndex("sum_name"), -1); + + EXPECT_EQ(function.LookupResultIndex("sum_name"), 0); + EXPECT_EQ(function.LookupResultIndex(""), -1); + EXPECT_EQ(function.LookupResultIndex("x"), -1); + EXPECT_EQ(function.LookupResultIndex("y"), -1); + EXPECT_EQ(function.LookupResultIndex("sum"), -1); + EXPECT_EQ(function.LookupResultIndex("x_name"), -1); + EXPECT_EQ(function.LookupResultIndex("y_name"), -1); + + // Check program shape. + using xla::ShapeUtil; + const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); + const xla::ProgramShape* program_shape = function.ProgramShape(); + ASSERT_TRUE(program_shape != nullptr); + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32)); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32)); + + const xla::Shape& result = program_shape->result(); + ASSERT_EQ(result.element_type(), xla::TUPLE); + ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1); + const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0); + EXPECT_TRUE(ShapeUtil::Compatible(result0, s32)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 1a8d03757a2b9bdee339ecb951a67528719314d4..21448686463bddd719340715bcf80987ef332caf 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -45,17 +45,16 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; -constexpr std::array kIntTypes = {{DT_INT32, DT_INT64}}; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kNumericTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}}; +constexpr std::array kNumericTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kCpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 2b142d933dbc8c5a7823f9426c423b59425a85bc..b6126981431dc9a3520b6c96321c453bc955e7c0 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -41,7 +41,9 @@ cc_library( srcs = ["padding.cc"], hdrs = ["padding.h"], deps = [ + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 387253617e4f37a1561d4659eb796a181f0b5bee..92cd8e729d659c4ff24c156d89f29275848c3cee 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -206,6 +206,7 @@ StatusOr> Client::Execute( *request.mutable_execution_options() = *execution_options; } for (GlobalData* argument : arguments) { + CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); } @@ -241,9 +242,6 @@ StatusOr>> Client::ExecuteParallel( for (GlobalData* argument : computation.arguments) { *single_request.add_arguments() = argument->handle(); } - if (computation.device_handle != nullptr) { - *single_request.mutable_device_handle() = *computation.device_handle; - } *single_request.mutable_execution_options() = computation.execution_options; *request.add_requests() = single_request; } diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index e72816a6217afd6a827642bbe3aa205409ef5718..a716159f9e74041c4823ad20b46fa94c2d7b9d8c 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -45,6 +45,10 @@ class Client { // * If execution_options is not nullptr, these options are passed to the // service to affect how it compiles our computation. (The pointer does not // need to live beyond this call.) + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. StatusOr> Execute( @@ -54,12 +58,13 @@ class Client { ExecutionProfile* execution_profile = nullptr); // A struct to represent a computation instance to be executed. - // * If device_handle is not nullptr, the computation is executed on a device - // associated with the handle. Otherwise, a device is chosen by the service. + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. struct ComputationInstance { const Computation& computation; std::vector arguments; - const DeviceHandle* device_handle; ExecutionOptions execution_options; ExecutionProfile* execution_profile; }; diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index a80412e951aaea74c36b3fd372c24fe6ec8f9fb8..dcbdb3525e8d4f397a9934f2658c7cc72b9144da 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -489,6 +489,16 @@ ComputationDataHandle ComputationBuilder::Collapse( } std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); + VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); + VLOG(3) << "dims to collapse: " + << tensorflow::str_util::Join(dims_to_collapse, ","); + + if (dims_to_collapse.size() <= 1) { + // Not collapsing anything, trivially we can return the operand versus + // enqueueing a trivial reshape. + return operand; + } + std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { @@ -498,6 +508,9 @@ ComputationDataHandle ComputationBuilder::Collapse( } } + VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") + << "]"; + return Reshape(operand, new_sizes); } @@ -942,21 +955,39 @@ ComputationDataHandle ComputationBuilder::Min( return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalAnd( +ComputationDataHandle ComputationBuilder::And( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LOGICAL_AND, lhs, rhs, broadcast_dimensions); + return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalOr( +ComputationDataHandle ComputationBuilder::Or( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LOGICAL_OR, lhs, rhs, broadcast_dimensions); + return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalNot( +ComputationDataHandle ComputationBuilder::Not( const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOGICAL_NOT, operand); + return UnaryOp(UNOP_NOT, operand); +} + +ComputationDataHandle ComputationBuilder::ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); } ComputationDataHandle ComputationBuilder::Abs( @@ -1433,10 +1464,20 @@ ComputationDataHandle ComputationBuilder::ReduceWindow( return ComputationDataHandle(); } - return ReduceWindowWithGeneralPadding( - operand, init_value, computation, window_dimensions, window_strides, + Status padding_valid = + ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), + window_dimensions, window_strides); + if (!padding_valid.ok()) { + first_error_ = padding_valid; + return ComputationDataHandle(); + } + + std::vector> padding_values = MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding)); + window_dimensions, window_strides, padding); + return ReduceWindowWithGeneralPadding(operand, init_value, computation, + window_dimensions, window_strides, + padding_values); } ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( @@ -1739,8 +1780,10 @@ void ComputationBuilder::SetDeviceAssignment( /* static */ ConvolutionDimensionNumbers ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(kConvBatchDimension); - dimension_numbers.set_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_input_batch_dimension(kConvBatchDimension); + dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_output_batch_dimension(kConvBatchDimension); + dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); dimension_numbers.set_kernel_output_feature_dimension( kConvKernelOutputDimension); dimension_numbers.set_kernel_input_feature_dimension( @@ -1754,15 +1797,17 @@ ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { /* static */ StatusOr ComputationBuilder::CreateConvDimensionNumbers( - int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 output_batch, + int64 output_feature, int64 first_spatial, int64 second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({batch, feature, first_spatial, second_spatial}).size() != - 4) { + if (std::set( + {input_batch, input_feature, first_spatial, second_spatial}) + .size() != 4) { return FailedPrecondition( "dimension numbers for the input are not unique: (%lld, %lld, %lld, " "%lld)", - batch, feature, first_spatial, second_spatial); + input_batch, input_feature, first_spatial, second_spatial); } if (std::set({kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial}) @@ -1773,9 +1818,19 @@ ComputationBuilder::CreateConvDimensionNumbers( kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial); } + if (std::set( + {output_batch, output_feature, first_spatial, second_spatial}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the output are not unique: (%lld, %lld, %lld, " + "%lld)", + output_batch, output_feature, first_spatial, second_spatial); + } ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(batch); - dimension_numbers.set_feature_dimension(feature); + dimension_numbers.set_input_batch_dimension(input_batch); + dimension_numbers.set_input_feature_dimension(input_feature); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); dimension_numbers.add_spatial_dimensions(first_spatial); dimension_numbers.add_spatial_dimensions(second_spatial); dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 73972c1290f5f2dda13cece9df09782c4ab0b709..cdd9c8847f56e25bcb807a9cf0631e72bf4355ee 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -201,6 +201,16 @@ class ComputationBuilder { // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must // be a consecutive, in-order subsequence of the operand dimensions. // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. ComputationDataHandle Collapse(const ComputationDataHandle& operand, @@ -344,7 +354,8 @@ class ComputationBuilder { // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an // error if either the input or the weight dimension numbers have conflicts. static StatusOr CreateConvDimensionNumbers( - int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 output_batch, + int64 output_feature, int64 first_spatial, int64 second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial); @@ -451,15 +462,25 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice broadcast_dimensions = {}); // Element-wise logical operators - ComputationDataHandle LogicalAnd( + ComputationDataHandle And( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle LogicalOr( + ComputationDataHandle Or( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle LogicalNot(const ComputationDataHandle& lhs); + ComputationDataHandle Not(const ComputationDataHandle& operand); + + ComputationDataHandle ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 969b0eee1d195a36728f16a598add4b3b850ed60..24048a1e5a782661ba577ba50e3b5b2914f17c0a 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -89,24 +89,24 @@ Computation CreateScalarMinComputation(PrimitiveType type, const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); } -Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder) { +Computation CreateScalarAndComputation(ComputationBuilder* builder) { return CreateScalarComputation( - "logical_and", PRED, builder, + "and", PRED, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->LogicalAnd(lhs, rhs); }); + const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); } -Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder) { +Computation CreateScalarOrComputation(ComputationBuilder* builder) { return CreateScalarComputation( - "logical_or", PRED, builder, + "or", PRED, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->LogicalOr(lhs, rhs); }); + const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); } StatusOr Any(const ComputationDataHandle& predicates, ComputationBuilder* builder) { auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarLogicalOrComputation(builder); + Computation logical_or = CreateScalarOrComputation(builder); TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, builder->GetShape(predicates)); std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f43d35fe4a52016d4054af28835d6b66a35217d4..ae89784bc227d837cf15f0a89687dd00dccc2745 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -45,10 +45,10 @@ Computation CreateScalarMinComputation(PrimitiveType type, ComputationBuilder* builder); // Creates a scalar logical AND computation and returns it. -Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder); +Computation CreateScalarAndComputation(ComputationBuilder* builder); // Creates a scalar logical OR computation and returns it. -Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder); +Computation CreateScalarOrComputation(ComputationBuilder* builder); // Returns whether any predicate in "predicates" is set. // diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index d45252d0f9e14377da523d3e82257d44ba772cf3..c885b815ebef60bbabfdbd97642d0be9bbbf49e8 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -283,11 +283,10 @@ StatusOr> LocalClient::Compile( int device_ordinal = options.device_ordinal() == -1 ? default_device_ordinal() : options.device_ordinal(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - options.result_layout(), device_ordinal, - options.has_hybrid_result())); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + local_service_->CompileExecutable( + computation.handle(), argument_layouts, + options.result_layout(), device_ordinal)); return WrapUnique(new LocalExecutable(std::move(executable), local_service_->mutable_backend(), device_ordinal, options)); diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 0b18d8946a2e62a810f875b4d79fd5375e787487..6a9cf466ac0a43ce214ef0e6aae9e6295f137b0f 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -17,17 +17,34 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { +Status ValidatePaddingValues( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides) { + bool ok = input_dimensions.size() == window_dimensions.size() && + input_dimensions.size() == window_strides.size(); + if (!ok) { + return InvalidArgument( + "Want input dimensions size %zu = window dimensions size %zu = window " + "strides size %zu", + input_dimensions.size(), window_dimensions.size(), + window_strides.size()); + } + return Status::OK(); +} + std::vector> MakePadding( tensorflow::gtl::ArraySlice input_dimensions, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - CHECK_EQ(input_dimensions.size(), window_dimensions.size()); - CHECK_EQ(input_dimensions.size(), window_strides.size()); + TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions, + window_strides)); std::vector> low_high_padding; switch (padding) { case Padding::kValid: diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h index dce2d87e8da8b3d9fd138a712c459ea0081372e0..e23b0b3a90a091bf80973525810793c3eda4a036 100644 --- a/tensorflow/compiler/xla/client/padding.h +++ b/tensorflow/compiler/xla/client/padding.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -37,6 +38,14 @@ enum class Padding { kValid, }; +// Validates that the slices are acceptable for determining padding -- this can +// be used to check the preconditions of MakePadding below to produce an error +// message that can be returned to the user. +Status ValidatePaddingValues( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides); + // Returns the padding needed for the base area, given the base area dimensions, // window dimensions, strides, and the type of padding. // @@ -51,7 +60,7 @@ enum class Padding { std::vector> MakePadding( tensorflow::gtl::ArraySlice input_dimensions, tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice strides, Padding padding); + tensorflow::gtl::ArraySlice window_strides, Padding padding); } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index cdc4139cd69c3d6eb4afc2e5d25f9446ffad0a11..c032cb8dc5adcbef9ffa64aa1e05bb5ccb49fc6a 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -51,21 +51,39 @@ StatusOr ToJson(const tensorflow::protobuf::Message& message) { return json_output; } -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name) { - TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message)); +namespace { - tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); - string safe_file_name = file_name + ".json"; +string SanitizeFilename(const string& file_name) { + string safe_file_name = file_name; for (char& c : safe_file_name) { if (c == '/' || c == '\\') { c = '_'; } } + return safe_file_name; +} + +} // namespace + +Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name) { + TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message)); + + tensorflow::Env* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string safe_file_name = SanitizeFileName(file_name) + ".json"; const string path = tensorflow::io::JoinPath(directory, safe_file_name); return tensorflow::WriteStringToFile(env, path, json_output); } +Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name) { + tensorflow::Env* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string safe_file_name = SanitizeFileName(file_name) + ".pb"; + const string path = tensorflow::io::JoinPath(directory, safe_file_name); + return tensorflow::WriteBinaryProto(env, path, message); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 1a895c3585902e8fbc0d20475c2817ef4caa4c71..7accb22e0c7720d5af896f8ca833ee26175fb89f 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -35,10 +35,12 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, // Returns 'message' as a JSON string. StatusOr ToJson(const tensorflow::protobuf::Message& message); -// Converts 'message' to JSON, and dumps it to the path formed by joining -// 'directory/file_name.json'. The 'directory' is recursively created if it -// doesn't already exist, and the 'file_name' is sanitized by replacing illegal -// characters with underscore '_'. +// Writes the given message in binary proto or JSON format to the path formed by +// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is +// recursively created if it doesn't already exist, and the 'file_name' is +// sanitized by replacing illegal characters with underscore '_'. +Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name); Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 35b5e8cd52ab0ec21a4bd2df3e9fa0538ae60816..eb6a71242ffa1499876b90f14f8a60ffdbdd069c 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -322,8 +322,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { // Set the convolution dimension numbers. ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(2); - dimension_numbers.set_feature_dimension(0); + dimension_numbers.set_input_batch_dimension(2); + dimension_numbers.set_input_feature_dimension(0); + dimension_numbers.set_output_batch_dimension(2); + dimension_numbers.set_output_feature_dimension(0); dimension_numbers.add_spatial_dimensions(1); dimension_numbers.add_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); @@ -374,8 +376,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { // Set the convolution dimension numbers. ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(2); - dimension_numbers.set_feature_dimension(0); + dimension_numbers.set_input_batch_dimension(2); + dimension_numbers.set_input_feature_dimension(0); + dimension_numbers.set_output_batch_dimension(2); + dimension_numbers.set_output_feature_dimension(0); dimension_numbers.add_spatial_dimensions(1); dimension_numbers.add_spatial_dimensions(3); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ee37cc8ad2dab6f46f559f9979e325b652b1260e..ed42358e7ed552dd9d9abb997eab3e2a2214cabe 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -717,6 +717,18 @@ cc_library( ], ) +tf_cc_test( + name = "name_uniquer_test", + srcs = ["name_uniquer_test.cc"], + deps = [ + ":name_uniquer", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "liveness_util", srcs = ["liveness_util.cc"], @@ -1051,7 +1063,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1116,7 +1128,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 26f85e93b015ed4e9e71493c1e1040defdca1188..39e8430ed335806a8b71f391ecfb30e2e3716633 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -98,11 +98,11 @@ bool ReshapeIsBitcast( HloComputation* CreateScalarBinaryComputation(HloModule* module, PrimitiveType primitive_type, HloOpcode opcode) { - HloComputation::Builder b("scalar computation"); + HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); @@ -912,9 +912,10 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A Broadcast that feeds a unary element-wise operation can sink the // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( - changed_, + bool sink_succeeded, TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); - if (changed_) { + changed_ |= sink_succeeded; + if (sink_succeeded) { return Status::OK(); } @@ -1217,9 +1218,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // A Reshape that feeds a unary element-wise operation can sink the // reshape after the unary element-wise operation. TF_ASSIGN_OR_RETURN( - changed_, + bool sink_succeeded, TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - if (changed_) { + changed_ |= sink_succeeded; + if (sink_succeeded) { return Status::OK(); } @@ -1262,6 +1264,11 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); } + // DynamicSlice where operand has the same size as the output and + // start_indices are all zero is simply equal to operand. + if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) { + return ReplaceInstruction(dynamic_slice, operand); + } return Status::OK(); } @@ -1280,8 +1287,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( // not to affect the visible behavior of this op even when the indices are out // of range. Currently dynamic-update-slice wraps out-of-range indices, so // we can only remove the op if its indices never wrap.) - if (start_indices->IsConstant() && start_indices->literal().IsAll(0) && - ShapeUtil::Compatible(dynamic_update_slice->shape(), update->shape())) { + if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } return Status::OK(); @@ -1505,7 +1511,10 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // still convert Conv into more efficient Matmul with operand transposition // (such as the transposition flags in cuBLAS SGEMM). if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || - input_shape.layout().minor_to_major(0) != dnums.feature_dimension() || + input_shape.layout().minor_to_major(0) != + dnums.input_feature_dimension() || + convolution_shape.layout().minor_to_major(0) != + dnums.output_feature_dimension() || // The input feature dimension should come later in the minor-to-major // order. (PositionInContainer(filter_shape.layout().minor_to_major(), @@ -1524,14 +1533,14 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Replace it with a dot, with bitcasts around it to get the right shape. const int64 input_channels = - input_shape.dimensions(dnums.feature_dimension()); + input_shape.dimensions(dnums.input_feature_dimension()); const int64 output_channels = filter_shape.dimensions(dnums.kernel_output_feature_dimension()); // Computes the product of the non-feature dimensions. int64 conv_width = 1; for (int i = 0; i < input_shape.dimensions_size(); ++i) { - if (i != dnums.feature_dimension()) { + if (i != dnums.input_feature_dimension()) { conv_width *= input_shape.dimensions(i); } } @@ -1782,7 +1791,7 @@ static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { // Tries to determine the number of times the given loop executes. Currently // simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(const HloInstruction* while_op) { +static optional GetLoopTripCount(HloInstruction* while_op) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); VLOG(2) << "Getting trip count for loop " << while_op->ToString(); @@ -1803,10 +1812,10 @@ static optional GetLoopTripCount(const HloInstruction* while_op) { // compute how many times the loop executes. Start by computing the induction // variable's initial value. HloEvaluator evaluator; - auto* while_init = while_op->operand(0); - auto* indvar_init = while_init->operand(*indvar_tuple_idx); + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init->Clone().get()); + evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init: " << indvar_init_result.status(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index cf97a261da957353de852427b3a7394e5f511d13..af502206e2ba85d89e208e0b8697273d2bf9b7ab 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -47,7 +47,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloTestBase { +class AlgebraicSimplifierTest : public HloVerifiedTestBase { public: // Makes a computation that contains a loop that runs num_iters times. HloComputation* MakeSimpleLoop(HloModule* module, int num_iters); @@ -1077,6 +1077,54 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { op::Maximum(op::Reshape(param), zero)); } +// Regression test for a bug where if we failed to sink a reshape, we'd set the +// 'changed' bit in AlgebraicSimplifier to false. +TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { + HloComputation::Builder builder(TestName()); + + // This add (param0 + 0) can be simplified. + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")), + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0, 0}, {0, 0}}))))); + + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); +} + +// Regression test for a bug where if we failed to sink a reshape, we'd set the +// 'changed' bit in AlgebraicSimplifier to false. +TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { + HloComputation::Builder builder(TestName()); + + // This add (param0 + 0) can be simplified. + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")), + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0, 0}, {0, 0}}))))); + + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0})); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -1530,7 +1578,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { for (int i = 0; i < strlen(options.dim_order); ++i) { char ch = options.dim_order[i]; if (ch == 'N') { - dnums.set_batch_dimension(i); + dnums.set_input_batch_dimension(i); + dnums.set_output_batch_dimension(i); in_dims.push_back(options.in_batch); } else if (ch == 'H') { dnums.set_spatial_dimensions(0, i); @@ -1539,7 +1588,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { dnums.set_spatial_dimensions(1, i); in_dims.push_back(options.in_width); } else if (ch == 'C') { - dnums.set_feature_dimension(i); + dnums.set_input_feature_dimension(i); + dnums.set_output_feature_dimension(i); in_dims.push_back(options.in_channels); in_channel_idx = i; } @@ -2165,6 +2215,29 @@ TEST_F(AlgebraicSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); } +// A dynamic-slice is trivial if its start indices are all zeroes and the size +// of its input equals the size of its output. In this case, the dynamic slice +// is equal to its input. +TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { + HloComputation::Builder builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + shape, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "slice_from")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), + /*slice_sizes=*/{10, 100, 1000})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Parameter()); +} + // A dynamic-update-slice is trivial if its start indices are all zeroes and the // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update. diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 688aff89125ce3e30be8918a9dfe9f17e22e6243..08a53af8baa3f250919517c87c023c329b129024 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -320,6 +320,13 @@ class BufferAssignment { const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const; + // Returns true if the top-level buffers of hlo_a and hlo_b are the same. + // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b). + bool SharesTopLevelSlice(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const { + return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); + } + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index b3784c36ff68a175c2f85b6e49419cc090766b80..a5b392cbc33c12c3255f3c06e9842fc116e672e5 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -69,7 +69,10 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.has_sender) { - return FailedPrecondition("channel handle is already used by a sender"); + return FailedPrecondition( + "when registering send, passed a channel handle that is already used " + "by a sender: %lld", + handle.handle()); } channel.has_sender = true; return Status::OK(); @@ -82,7 +85,10 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { Channel& channel = opaque_to_channel_[handle.handle()]; // TODO(b/33942691): Allow more than 1 receivers for broadcast. if (channel.receiver_count >= 1) { - return FailedPrecondition("channel handle is already used by a receiver"); + return FailedPrecondition( + "when registering recv, passed a channel handle that is already used " + "by a receiver: %lld", + handle.handle()); } channel.receiver_count += 1; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index c95670b1954bada51488a8b3722ca911b98b69a2..9e96898d9b4215e67c8686d372e4b4e6edd1d88b 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -101,8 +101,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, - /*has_hybrid_result=*/false)); + &execution_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, computation_tracker_.BuildHloModule( diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index f71b2b6b9c65c63e6ca211004b1df5cc39aef5fa..3b1900428af1863c73efe67c27061d979557b3a4 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -58,7 +58,8 @@ Compiler::GetPlatformCompilers() { LazyInitMutex(); tensorflow::mutex_lock lock(*platform_compiler_mutex_); auto* factories = GetPlatformCompilerFactories(); - CHECK(factories->find(platform_id) == factories->end()); + CHECK(factories->find(platform_id) == factories->end()) + << "Compiler factory already registered for platform"; (*factories)[platform_id] = std::move(compiler_factory); } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d5bd9214be44f4abd5f672168335ae1a259c9118..4c2d9600d909e82dcb62f508a10445c08c1cdee6 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -114,7 +114,8 @@ class Compiler { // sequence of executable objects. virtual StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) = 0; + std::vector> + stream_exec) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index a2969d23d6dac2816080ec7e446571588d87662a..c71eca0d394830904e37b41ae1499678e75e062f 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -48,6 +48,29 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) +cc_library( + name = "external_constant_pool", + srcs = ["external_constant_pool.cc"], + hdrs = ["external_constant_pool.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "external_constant_pool_test", + srcs = ["external_constant_pool_test.cc"], + deps = [ + ":external_constant_pool", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], @@ -64,6 +87,7 @@ cc_library( ":ir_emitter", ":layout_assignment", ":parallel_cpu_executable", + ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -130,7 +154,9 @@ cc_library( ":cpu_runtime_neon", ":cpu_runtime_sse4_1", ":disassembler", + ":external_constant_pool", ":runtime_conv2d", + ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", @@ -217,7 +243,9 @@ cc_library( ":cpu_options", ":cpu_runtime", ":dot_op_emitter", + ":external_constant_pool", ":ir_emission_utils", + ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -238,6 +266,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -479,6 +508,20 @@ cc_library( ], ) +cc_library( + name = "runtime_fork_join", + srcs = ["runtime_fork_join.cc"], + hdrs = ["runtime_fork_join.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], @@ -543,6 +586,7 @@ cc_library( ], deps = [ ":ir_emission_utils", + ":parallel_task_assignment", ":shape_partition", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -652,6 +696,19 @@ tf_cc_test( ], ) +cc_library( + name = "parallel_task_assignment", + srcs = ["parallel_task_assignment.cc"], + hdrs = ["parallel_task_assignment.h"], + deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/compiler/xla/service:hlo_pass", + ], +) + cc_library( name = "cpu_options", srcs = ["cpu_options.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 069979c6611e90ed2d95cbbe341198577cdf56cf..44cd2171afdc6eecc22f3f920276a4d95f930573 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -36,8 +36,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { !PotentiallyImplementedAsEigenConvolution(*hlo)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); - auto batch_dim = dnums.batch_dimension(); - auto feature_dim = dnums.feature_dimension(); + auto input_batch_dim = dnums.input_batch_dimension(); + auto input_feature_dim = dnums.input_feature_dimension(); auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension(); auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension(); @@ -59,15 +59,16 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_input_dim_order(num_dims); std::vector new_input_dims(num_dims); - new_input_dim_order[0] = batch_dim; - new_input_dims[0] = input->shape().dimensions(batch_dim); + new_input_dim_order[0] = input_batch_dim; + new_input_dims[0] = input->shape().dimensions(input_batch_dim); for (int i = 0; i < num_spatial_dims; ++i) { new_input_dim_order[i + 1] = dnums.spatial_dimensions(i); new_input_dims[i + 1] = input->shape().dimensions(dnums.spatial_dimensions(i)); } - new_input_dim_order[num_dims - 1] = feature_dim; - new_input_dims[num_dims - 1] = input->shape().dimensions(feature_dim); + new_input_dim_order[num_dims - 1] = input_feature_dim; + new_input_dims[num_dims - 1] = + input->shape().dimensions(input_feature_dim); Shape new_input_shape = ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims); @@ -98,22 +99,26 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { new_kernel_dim_order)); std::vector new_conv_dims(num_dims); - new_conv_dims[0] = hlo->shape().dimensions(batch_dim); + auto output_batch_dim = dnums.output_batch_dimension(); + auto output_feature_dim = dnums.output_feature_dimension(); + new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim); for (int i = 0; i < num_spatial_dims; ++i) { new_conv_dims[i + 1] = hlo->shape().dimensions(dnums.spatial_dimensions(i)); } - new_conv_dims[num_dims - 1] = hlo->shape().dimensions(feature_dim); + new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim); Shape new_conv_shape = ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); ConvolutionDimensionNumbers new_dnums; - new_dnums.set_batch_dimension(0); + new_dnums.set_input_batch_dimension(0); + new_dnums.set_output_batch_dimension(0); for (int i = 0; i < num_spatial_dims; ++i) { new_dnums.add_spatial_dimensions(i + 1); new_dnums.add_kernel_spatial_dimensions(i); } - new_dnums.set_feature_dimension(num_dims - 1); + new_dnums.set_input_feature_dimension(num_dims - 1); + new_dnums.set_output_feature_dimension(num_dims - 1); new_dnums.set_kernel_input_feature_dimension(num_dims - 2); new_dnums.set_kernel_output_feature_dimension(num_dims - 1); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 9e8b785f30559f493bcec546e0612f2290af031d..d593ba26b655d00a0f0f0b9a94c9e62fa1835080 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -67,10 +67,12 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(1); + dnums.set_input_batch_dimension(1); + dnums.set_output_batch_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); - dnums.set_feature_dimension(0); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); dnums.add_kernel_spatial_dimensions(2); dnums.add_kernel_spatial_dimensions(3); dnums.set_kernel_input_feature_dimension(1); @@ -121,10 +123,12 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index c30f9ea1947204adbe7aad31fcdf04c9e345c882..ce4d109214b9ad236fbf125179276bf53f4cbf57 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -58,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -86,10 +87,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/env.h" namespace se = ::perftools::gputools; @@ -250,7 +249,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); @@ -271,6 +270,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { { auto& pass = pipeline.AddPass>("simplification"); + pass.AddInvariantChecker(ShapeSizeBytesFunction()); + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, @@ -316,6 +317,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { if (options::CpuParallelBackendRequested(module->config())) { pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); + } else if (!is_aot_compile) { + // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. + // Note this is not run for AOT because it would bring in thread pool + // and thread synchronization dependencies which would likely increase + // binary size (and most AOT applications are single-threaded). + // TODO(29630486) Support multi-threaded AOT. + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction(), module); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -367,68 +376,50 @@ llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { } } -Status AppendIRToFile(const string& file_name, const string& ir_module_string) { - std::unique_ptr f; - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->NewWritableFile(file_name, &f)); - TF_RETURN_IF_ERROR(f->Append(ir_module_string)); - TF_RETURN_IF_ERROR(f->Close()); - return Status::OK(); -} - Status InitializeModuleHooks( - const HloModule& module, + const HloModule& hlo_module, const LLVMCompiler::ModuleHook& user_pre_optimization_hook, const LLVMCompiler::ModuleHook& user_post_optimization_hook, LLVMCompiler::ModuleHook* pre_optimization_ir_hook, LLVMCompiler::ModuleHook* post_optimization_ir_hook) { - const string& dump_ir_to = module.config().debug_options().xla_dump_ir_to(); - if (dump_ir_to.empty()) { + const string& ir_dump_directory = + hlo_module.config().debug_options().xla_dump_ir_to(); + if (ir_dump_directory.empty()) { *pre_optimization_ir_hook = user_pre_optimization_hook; *post_optimization_ir_hook = user_post_optimization_hook; return Status::OK(); } - // Initialize the output directory and create the output file names. - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->RecursivelyCreateDir(dump_ir_to)); - string safe_file_name_base = module.name(); - std::replace_if(safe_file_name_base.begin(), safe_file_name_base.end(), - [](char c) { return c == '/' || c == '\\'; }, '_'); - - string unoptimized_ir_file_name = tensorflow::io::JoinPath( - dump_ir_to, - tensorflow::strings::StrCat("ir-", safe_file_name_base, "-no-opt.ll")); - string optimized_ir_file_name = tensorflow::io::JoinPath( - dump_ir_to, - tensorflow::strings::StrCat("ir-", safe_file_name_base, "-opt.ll")); + const string& hlo_module_name = hlo_module.name(); // Create the IR hooks. If applicable, each IR hook does the following: - // * Call the user supplied module hook. - // * Write to the output directory. Files will be appended to. We still want - // to append to avoid overwriting possibly important information due to - // operator error. + // + // * Calls the user supplied module hook. + // * Writes out the IR to a file in the output directory designated by + // --xla_dump_ir_to *pre_optimization_ir_hook = - [user_pre_optimization_hook, - unoptimized_ir_file_name](const llvm::Module& module) { + [user_pre_optimization_hook, ir_dump_directory, + hlo_module_name](const llvm::Module& llvm_module) { if (user_pre_optimization_hook) { - TF_RETURN_IF_ERROR(user_pre_optimization_hook(module)); + TF_RETURN_IF_ERROR(user_pre_optimization_hook(llvm_module)); } - TF_RETURN_IF_ERROR(AppendIRToFile(unoptimized_ir_file_name, - llvm_ir::DumpModuleToString(module))); - return Status::OK(); + return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/hlo_module_name, + llvm_module, + /*optimized=*/false); }; *post_optimization_ir_hook = - [user_post_optimization_hook, - optimized_ir_file_name](const llvm::Module& module) { + [user_post_optimization_hook, ir_dump_directory, + hlo_module_name](const llvm::Module& llvm_module) { if (user_post_optimization_hook) { - TF_RETURN_IF_ERROR(user_post_optimization_hook(module)); + TF_RETURN_IF_ERROR(user_post_optimization_hook(llvm_module)); } - TF_RETURN_IF_ERROR(AppendIRToFile(optimized_ir_file_name, - llvm_ir::DumpModuleToString(module))); - return Status::OK(); + return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/hlo_module_name, + llvm_module, + /*optimized=*/true); }; return Status::OK(); @@ -468,7 +459,7 @@ StatusOr> CpuCompiler::Compile( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get())); + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; @@ -542,7 +533,8 @@ StatusOr> CpuCompiler::Compile( } IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine()); + &hlo_to_profile_idx, jit->target_machine(), + jit->external_constant_pool()); std::unique_ptr> function_names( new std::map()); @@ -622,7 +614,8 @@ StatusOr> CpuCompiler::Compile( // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine()); + &hlo_to_profile_idx, jit->target_machine(), + jit->external_constant_pool()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -669,7 +662,7 @@ StatusOr> CpuCompiler::Compile( StatusOr>> CpuCompiler::Compile( std::vector> modules, - std::vector stream_execs) { + std::vector> stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on CPU."); } @@ -765,7 +758,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, HloModule* module = modules[i].get(); VLOG(1) << "Compiling ahead-of-time: " << module->name(); - TF_RETURN_IF_ERROR(RunHloPasses(module)); + TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, @@ -791,7 +784,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, } IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr, target_machine.get()); + /*hlo_to_profile_idx=*/nullptr, target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index bd3541500dae9d9d59c56bfb062912a1b85c2219..d09130247421b11d6d4879466f39b89167eb9564 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -115,7 +115,8 @@ class CpuCompiler : public LLVMCompiler { StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) override; + std::vector> + stream_execs) override; StatusOr>> CompileAheadOfTime(std::vector> modules, @@ -131,7 +132,7 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module); + Status RunHloPasses(HloModule* module, bool is_aot_compile); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 0d68aa7399a9c6b8aed26a6aa5bd4e90a503a92f..238bc9b46ae2bf1b519eaf137d9ae063e769bd2e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -87,6 +87,17 @@ class CpuExecutable : public Executable { std::unique_ptr CreateCostAnalysis() const override; + // Type of the computation function we expect in the JIT. + using ComputeFunctionType = void (*)( + void* /*result*/, const ExecutableRunOptions* /*run_options*/, + const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/); + + const ComputeFunctionType& compute_function() const { + return compute_function_; + } + + const BufferAssignment& buffer_assignment() const { return *assignment_; } + private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer @@ -129,11 +140,6 @@ class CpuExecutable : public Executable { // positives. string ir_module_string_; - // Type of the computation function we expect in the JIT. - // void function(void* result, const void* run_options, - // const void** args_array, void** temps_array) - using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - uint64*); ComputeFunctionType compute_function_; // Entry function name for the computation. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index 8c827efefcbed6caf9b58a3394dafb4c82463f31..662ee609232f5582ce74f4f515637b2623175e94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -109,34 +110,15 @@ StatusOr ParallelizationPreparation::RunParallelTaskAssignment( HloModule* module) { VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_; bool changed = false; - // Run cost analysis on entry computation. - HloCostAnalysis cost_analysis(shape_size_); + // Initialize ParallelTaskAssignment. + ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_, + module); + // Assign parallel tasks to HLOs in entry computation. HloComputation* computation = module->entry_computation(); - Status cost_status = computation->root_instruction()->Accept(&cost_analysis); for (auto* instruction : computation->instructions()) { - // Currently, we do not assign parallel tasks to instructions with at least - // one of the following properties: - // *) Internal threading (library calls to kConv, kDot, and kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). - // *) Tuple-shaped. - // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || - ShapeUtil::IsTuple(instruction->shape())) { - continue; - } - // Calculate target parallel task count in [1, max_parallelism_]. - const int64 target_parallel_task_count = GetTargetParallelTaskCount( - cost_status.ok() ? &cost_analysis : nullptr, instruction); + const int64 target_parallel_task_count = + parallel_task_assignment.GetTargetParallelTaskCount(instruction); if (target_parallel_task_count == 1) { continue; } @@ -159,30 +141,6 @@ StatusOr ParallelizationPreparation::RunParallelTaskAssignment( return changed; } -int64 ParallelizationPreparation::GetTargetParallelTaskCount( - const HloCostAnalysis* cost_analysis, HloInstruction* instruction) { - // Default to a simple cost model based on hlo size and typical L2 cache size. - // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an - // error status (likely because HLOs like CustomCall are not yet implemented - // in the HloCostAnalysis). - int64 instruction_cost = shape_size_(instruction->shape()); - int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. - if (cost_analysis != nullptr) { - // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. - // Consider making 'min_cost_per_thread' be a function of the target - // bandwidth limit for instructions with low arithmetic complexity. - instruction_cost = 1 * cost_analysis->flop_count(*instruction) + - 2 * cost_analysis->transcendental_count(*instruction) + - 10 * cost_analysis->bytes_accessed(*instruction); - // Minimum per-thread cost is 100us of work on a 2GHz core. - min_cost_per_thread = 100000; - } - // Return target parallel task count in [1, max_parallelism_]. - return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); -} - bool ParallelizationPreparation::OutlineParallelizableInstruction( HloInstruction* instruction) { if (instruction->outer_dimension_partitions().empty()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h index d53fc461509cad51778dba37922212731236952f..87be758ef5d0535fdce3a65e54ce225042019cdb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h @@ -55,12 +55,6 @@ class ParallelizationPreparation : public HloPassInterface { // Returns true on success or error status otherwise. StatusOr RunParallelTaskAssignment(HloModule* module); - // Returns the target parallel task count for 'instruction'. - // Utilizes 'cost_analysis' if non-null. - // Otherwise defaults to a simple HLO output size-based cost model. - int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis, - HloInstruction* instruction); - // Outlines 'instruction' from entry computation, if it had // been assigned parallel tasks in an earlier pass through the computation. // Returns true if 'instruction' was successfully outlined, false otherwise. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index c7155b858bda5e5640e9a6719fb394ca1360d128..7908dc173d79a4a9dcb6127ac344267e27d2b5f2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -51,6 +51,9 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName = "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation"; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; +extern const char* const kParallelForkJoinSymbolName = + "__xla_cpu_runtime_ParallelForkJoin"; + extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 29feb7267fe97f6876827b6cbfa6217a0cecf238..2ade455b8a0a43dda8c93bbb79891439da2e4f75 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -51,6 +51,7 @@ extern const char* const kAcquireInfeedBufferForDequeueSymbolName; extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; +extern const char* const kParallelForkJoinSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9f8e5584965d0c73771750e26bd63c401d5b0c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { +namespace cpu { +void ExternalConstantPool::Insert(string name, const Literal& literal, + int64 alignment) { + CHECK(!ShapeUtil::IsTuple(literal.shape())); + CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); + CHECK(entries_.find(name) == entries_.end()); + + int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + void* raw_pointer; + CHECK_EQ( + posix_memalign(&raw_pointer, std::max(alignment, sizeof(void*)), + literal_size), + 0) + << "failed to allocate " << literal_size << " bytes with alignment of " + << alignment; + + std::memcpy(raw_pointer, literal.InternalData(), literal_size); + entries_.emplace(std::move(name), static_cast(raw_pointer)); +} + +const uint8* ExternalConstantPool::Find(const string& name) { + auto it = entries_.find(name); + return it == entries_.end() ? nullptr : it->second.get(); +} +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..ade28cbcbcfda05a9ad0adab1139bf316720e11f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { +namespace cpu { +// An ExternalConstantPool maintains a set of constants kept external to +// generated LLVM IR. These constants are accessed from the IR via globals with +// extern linkage. This current incarnation of ExternalConstantPool only +// supports the JIT CPU backend; the AOT backend is not supported. +// +// Implementation-wise, this is a simple wrapper around a map of strings to byte +// buffers. This simply implementation works in a JIT scenario. This class +// will have to become smarter if we decide to support external constant pools +// on AOT compiles in the future. +class ExternalConstantPool { + public: + // Inserts a buffer with the contents of `literal` into the constant pool with + // the name `name`. It is an error to try to insert two constants with the + // same `name` into the same constant pool. The buffer for literal is aligned + // to `aligment` bytes, and `alignment` must be a power of 2. + // + // The constant pool copies out the contents of `literal` into a buffer it + // owns -- it does not keep pointers to `literal`, or to memory owned by + // `literal`. + void Insert(string name, const Literal& literal, int64 alignment); + + // Find the constant with name `name` in this constant pool. If there isn't + // such constant, return nullptr. + const uint8* Find(const string& name); + + private: + // We need to `free()` pointers allocated into `entries_` since we allocate + // them with `posix_memalign`. + struct FreeDeleter { + void operator()(void* ptr) { free(ptr); } + }; + + tensorflow::gtl::FlatMap> + entries_; +}; +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9290a4e5dfc03ddb86e9d82f1f0f4f9a8ceebb88 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class ExternalConstantPoolTest : public ::testing::Test {}; + +template +T GetFromBuffer(const uint8* buffer, int64 index) { + T result; + std::memcpy(&result, buffer + index * sizeof(T), sizeof(T)); + return result; +} + +TEST(ExternalConstantPoolTest, Basic) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); + constant_pool.Insert("name-0", *literal, 4); + const uint8* constant = constant_pool.Find("name-0"); + ASSERT_NE(constant, nullptr); + + EXPECT_EQ(GetFromBuffer(constant, 0), 1); + EXPECT_EQ(GetFromBuffer(constant, 1), 2); + EXPECT_EQ(GetFromBuffer(constant, 2), 3); + EXPECT_EQ(GetFromBuffer(constant, 3), 4); + + EXPECT_EQ(constant_pool.Find("name-1"), nullptr); +} + +TEST(ExternalConstantPoolTest, RowMinorLayout) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + const auto literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); + constant_pool.Insert("name-0", *literal, 4); + const uint8* constant = constant_pool.Find("name-0"); + ASSERT_NE(constant, nullptr); + + EXPECT_EQ(GetFromBuffer(constant, 0), 1); + EXPECT_EQ(GetFromBuffer(constant, 1), 3); + EXPECT_EQ(GetFromBuffer(constant, 2), 2); + EXPECT_EQ(GetFromBuffer(constant, 3), 4); +} + +TEST(ExternalConstantPoolTest, Alignment) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + + for (int i = 0; i < 8; i++) { + int64 alignment = 1 << i; + string name = tensorflow::strings::StrCat("name-", i); + + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); + constant_pool.Insert(name, *literal, alignment); + + const uint8* constant = constant_pool.Find(name); + ASSERT_NE(constant, nullptr); + EXPECT_EQ(reinterpret_cast(constant) % alignment, 0); + } +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 91b09f2472e4001d8df8aa1ce4dc2796af2a32e7..ea5b6ca4ebfd8d67681da48b0c43a95ca3685a8e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -55,8 +55,12 @@ bool PotentiallyImplementedAsEigenConvolution( std::is_sorted(dnums.kernel_spatial_dimensions().begin(), dnums.kernel_spatial_dimensions().end()); - return dnums.batch_dimension() == 0 && - dnums.feature_dimension() == input_shape.dimensions_size() - 1 && + const Shape& output_shape = convolution.shape(); + return dnums.input_batch_dimension() == 0 && + dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && + dnums.output_batch_dimension() == 0 && + dnums.output_feature_dimension() == + output_shape.dimensions_size() - 1 && input_spatial_dims_ascending == kernel_spatial_dims_ascending && dnums.kernel_input_feature_dimension() == kernel_shape.dimensions_size() - 2 && diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 1a2302616adb3f9b0962c0f479f9856a9a25b248..52085d13763e082c153fb469e51a55ad9e12ef24 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -49,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -75,7 +77,8 @@ IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx, - llvm::TargetMachine* target_machine) + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -86,7 +89,8 @@ IrEmitter::IrEmitter( parallel_cpu_backend_( options::CpuParallelBackendRequested(hlo_module_config_)), is_top_level_computation_(false), - target_machine_features_(target_machine) { + target_machine_features_(target_machine), + external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); @@ -183,20 +187,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { // Even though the type of params and temps is void** in the host's view, in // LLVM IR this is represented by i8*, similarly to void*. It's up to the code // to use GEPs to unravel the indirection layers. - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (IsParallelContext()) { - compute_function_params.push_back(i64_ptr_type); - } - if (hlo_to_profile_idx_) { - compute_function_params.push_back(i64_ptr_type); - } llvm::FunctionType* compute_function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, + /*Params=*/GetComputeFunctionParams(), /*isVarArg=*/false); // Functions with local linkage get an inlining bonus. Because we know @@ -218,7 +211,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { (++arg_iter)->setName("run_options"); (++arg_iter)->setName("params"); (++arg_iter)->setName("temps"); - if (IsParallelContext()) { + if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); } if (hlo_to_profile_idx_) { @@ -272,15 +265,39 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { Status IrEmitter::HandleConstant(HloInstruction* constant, const Literal& literal) { VLOG(2) << "HandleConstant: " << constant->ToString(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/initializer->getType(), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/initializer, - /*Name=*/""); + llvm::GlobalVariable* global_for_const; + + // We avoid creating large constants in the LLVM IR since LLVM is not + // efficient for large constant arrays. We still emit "small enough" constant + // arrays into the Ir, in the off chance the LLVM optimizer can do something + // interesting with it. + const int kMaxInternalConstantSizeInBytes = 128; + if (external_constant_pool_ && + ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { + string global_name = tensorflow::strings::StrCat( + "constant_global_", external_global_constant_counter_++); + global_for_const = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/IrShapeType(literal.shape()), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/nullptr, + /*Name=*/AsStringRef(global_name)); + global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + external_constant_pool_->Insert(global_name, literal, + MinimumAlignmentForShape(literal.shape())); + } else { + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); + global_for_const = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + } emitted_value_[constant] = global_for_const; VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); VLOG(2) << " its type: " @@ -291,8 +308,7 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. - TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); - emitted_value_[copy] = copy_value; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. @@ -304,17 +320,23 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for - // allocations smaller than 16 bytes and at least alignment 16 for allocations - // greater than or equal to 16 bytes. N.B. We could improve on this lower - // bound by explicitly allocating the memory with posix_memalign. This is + // allocations smaller than kMallocAlignmentThreshold bytes and at least + // alignment 16 for allocations greater than or equal to + // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound + // by explicitly allocating the memory with posix_memalign. This is // complicated by our desire to allow parameter buffers created by clients to // be consumed directly by the JIT. if (buffer_size == 0) { // No need to align empty buffers. return 1; } + + const int64 kMallocAlignmentThreshold = 512; + int pointer_size = module_->getDataLayout().getPointerSize(); - int buffer_alignment = buffer_size >= 16 ? 2 * pointer_size : 8; + int buffer_alignment = buffer_size >= kMallocAlignmentThreshold + ? 2 * pointer_size + : pointer_size; DCHECK_GT(buffer_alignment, 0); return buffer_alignment; @@ -389,10 +411,8 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(select)); - emitted_value_[select] = output_address; - llvm_ir::EmitTupleSelect(GetIrArrayForOp(select), GetIrArrayForOp(pred), + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); + llvm_ir::EmitTupleSelect(GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), GetEmittedValueFor(on_false), &ir_builder_); return Status::OK(); @@ -408,8 +428,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { // The infeed operation produces data (dequeued from the infeed queue) at this // address, which has been provided by buffer assignment. - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(infeed)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); + llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); if (ShapeUtil::IsTuple(shape)) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); @@ -427,9 +447,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { ShapeUtil::GetTupleElementShape(shape, i); // Only the outer tuple buffer's target address is obtained from - // EmitTargetAddressForOp to handle the case when Infeed is the - // root instruction. Target addresses for internal elements can - // be obtained from EmitTempBufferPointer. + // GetEmittedValueFor, to handle the case when Infeed is the root + // instruction. Target addresses for internal elements can be obtained + // from EmitTempBufferPointer. llvm::Value* tuple_element_address = EmitTempBufferPointer(buffer, tuple_element_shape); @@ -439,15 +459,12 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape), - tuple_element_addresses, &ir_builder_); + llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_); } else { - TF_RETURN_IF_ERROR( - EmitXfeedTransfer(XfeedKind::kInfeed, shape, target_address)); + TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, + GetEmittedValueFor(infeed))); } - emitted_value_[infeed] = target_address; - return Status::OK(); } @@ -561,15 +578,12 @@ Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) { Status IrEmitter::HandleTuple( HloInstruction* tuple, tensorflow::gtl::ArraySlice operands) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(tuple)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple)); std::vector base_ptrs; for (auto operand : operands) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, tuple->shape()), - base_ptrs, &ir_builder_); - emitted_value_[tuple] = target_address; + llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_); return Status::OK(); } @@ -584,7 +598,7 @@ Status IrEmitter::HandleMap( const llvm_ir::IrArray::Index& index) { std::vector parameter_addresses; for (const HloInstruction* operand : operands) { - const llvm_ir::IrArray& array = GetIrArrayForOp(operand); + const llvm_ir::IrArray& array = GetIrArrayFor(operand); parameter_addresses.push_back( array.EmitArrayElementAddress(index, &ir_builder_)); } @@ -680,7 +694,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, SetToFirstInsertPoint(if_data.true_block, &ir_builder_); // We are not in the padding, so carry out the computation. - llvm_ir::IrArray input_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value_address = input_array.EmitArrayElementAddress(input_index, &ir_builder_); llvm::Value* result = EmitElementFunctionCall( @@ -817,7 +831,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); ir_builder_.CreateStore(operand_data, selected_value_address); @@ -860,10 +874,10 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back( ir_builder_.CreateLoad(selected_index_address_slot)); } - llvm_ir::IrArray source_array(GetIrArrayForOp(source)); + llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value_address = source_array.EmitArrayElementAddress(source_index, &ir_builder_); - llvm_ir::IrArray output_array(GetIrArrayForOp(select_and_scatter)); + llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); llvm::Value* output_value_address = output_array.EmitArrayElementAddress(selected_index, &ir_builder_); llvm::Value* scatter_value = EmitElementFunctionCall( @@ -883,14 +897,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, /*instruction=*/*dot, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32, F64})); - llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - Shape target_shape = dot->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dot)); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*dot, &target_array); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot)); + llvm_ir::IrArray target_array = GetIrArrayFor(dot); VLOG(2) << "HandleDot: "; VLOG(2) << " lhs operand: " @@ -901,13 +912,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + return DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_)); - - emitted_value_[dot] = target_address; - return Status::OK(); + hlo_module_config_); } Status IrEmitter::HandleConvolution(HloInstruction* convolution, @@ -935,21 +943,21 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, bool one_dim_convolution = lhs_shape.dimensions_size() == 3; llvm::Value* lhs_address = GetEmittedValueFor(lhs); llvm::Value* rhs_address = GetEmittedValueFor(rhs); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(convolution)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution)); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); // Input tensor. const Shape& input_shape = convolution->operand(0)->shape(); - int64 input_batch = input_shape.dimensions(dnums.batch_dimension()); + int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension()); int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0)); int64 input_cols = one_dim_convolution ? 1 : input_shape.dimensions(dnums.spatial_dimensions(1)); - int64 input_channels = input_shape.dimensions(dnums.feature_dimension()); + int64 input_channels = + input_shape.dimensions(dnums.input_feature_dimension()); // Kernel tensor. const Shape& kernel_shape = convolution->operand(1)->shape(); @@ -1018,35 +1026,33 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); ir_builder_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast(target_address, float_ptr_type), - ir_builder_.CreateBitCast(lhs_address, float_ptr_type), - ir_builder_.CreateBitCast(rhs_address, float_ptr_type), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(input_rows), - ir_builder_.getInt64(input_cols), - ir_builder_.getInt64(input_channels), - ir_builder_.getInt64(kernel_rows), - ir_builder_.getInt64(kernel_cols), - ir_builder_.getInt64(kernel_channels), - ir_builder_.getInt64(kernel_filters), - ir_builder_.getInt64(output_rows), - ir_builder_.getInt64(output_cols), - ir_builder_.getInt64(row_stride), - ir_builder_.getInt64(col_stride), - ir_builder_.getInt64(padding_top), - ir_builder_.getInt64(padding_bottom), - ir_builder_.getInt64(padding_left), - ir_builder_.getInt64(padding_right), - ir_builder_.getInt64(lhs_row_dilation), - ir_builder_.getInt64(lhs_col_dilation), - ir_builder_.getInt64(rhs_row_dilation), - ir_builder_.getInt64(rhs_col_dilation), - }); - target_address->setName(AsStringRef(IrName(convolution))); - emitted_value_[convolution] = target_address; + conv_func, { + GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast( + GetEmittedValueFor(convolution), float_ptr_type), + ir_builder_.CreateBitCast(lhs_address, float_ptr_type), + ir_builder_.CreateBitCast(rhs_address, float_ptr_type), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(input_rows), + ir_builder_.getInt64(input_cols), + ir_builder_.getInt64(input_channels), + ir_builder_.getInt64(kernel_rows), + ir_builder_.getInt64(kernel_cols), + ir_builder_.getInt64(kernel_channels), + ir_builder_.getInt64(kernel_filters), + ir_builder_.getInt64(output_rows), + ir_builder_.getInt64(output_cols), + ir_builder_.getInt64(row_stride), + ir_builder_.getInt64(col_stride), + ir_builder_.getInt64(padding_top), + ir_builder_.getInt64(padding_bottom), + ir_builder_.getInt64(padding_left), + ir_builder_.getInt64(padding_right), + ir_builder_.getInt64(lhs_row_dilation), + ir_builder_.getInt64(lhs_col_dilation), + ir_builder_.getInt64(rhs_row_dilation), + ir_builder_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1066,8 +1072,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, for (int i = 0; i < num_spatial_dims; ++i) { output_spatial[i] = index[dnums.spatial_dimensions(i)]; } - llvm::Value* output_feature = index[dnums.feature_dimension()]; - llvm::Value* batch = index[dnums.batch_dimension()]; + llvm::Value* output_feature = index[dnums.output_feature_dimension()]; + llvm::Value* batch = index[dnums.output_batch_dimension()]; // We will accumulate the products into this sum to calculate // the output entry at the given index. @@ -1091,8 +1097,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, } llvm::Value* input_feature = loops - .AddLoop(0, lhs->shape().dimensions(dnums.feature_dimension()), - "iz") + .AddLoop( + 0, lhs->shape().dimensions(dnums.input_feature_dimension()), + "iz") ->GetIndVarValue(); SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); @@ -1172,10 +1179,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, for (int i = 0; i < num_spatial_dims; ++i) { input_index[dnums.spatial_dimensions(i)] = input_spatial[i]; } - input_index[dnums.feature_dimension()] = input_feature; - input_index[dnums.batch_dimension()] = batch; + input_index[dnums.input_feature_dimension()] = input_feature; + input_index[dnums.input_batch_dimension()] = batch; - llvm_ir::IrArray kernel_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); llvm_ir::IrArray::Index kernel_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i]; @@ -1183,7 +1190,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; - llvm_ir::IrArray input_array(GetIrArrayForOp(lhs)); + llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = ir_builder_.CreateFMul( input_array.EmitReadArrayElement(input_index, &ir_builder_), kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_)); @@ -1317,7 +1324,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm_ir::IrArray::Index input_index = FillReducedDimensionIndex(reduced_dims_index, index); llvm::Value* new_value = @@ -1361,9 +1368,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { mean_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "mean_var"))); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(batch_norm_training)); - + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training)); TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); @@ -1393,7 +1398,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { llvm::Value* var = var_array.EmitReadArrayElement( feature_index_value, &ir_builder_); - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* input = operand_array.EmitReadArrayElement(index, &ir_builder_); @@ -1405,10 +1410,10 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); llvm::Value* normalized = ir_builder_.CreateFDiv( ir_builder_.CreateFSub(input, mean), variance_sqrt); - llvm_ir::IrArray offset_array(GetIrArrayForOp(offset)); + llvm_ir::IrArray offset_array(GetIrArrayFor(offset)); llvm::Value* offset = offset_array.EmitReadArrayElement( feature_index_value, &ir_builder_); - llvm_ir::IrArray scale_array(GetIrArrayForOp(scale)); + llvm_ir::IrArray scale_array(GetIrArrayFor(scale)); llvm::Value* scale = scale_array.EmitReadArrayElement( feature_index_value, &ir_builder_); llvm::Value* result = ir_builder_.CreateFAdd( @@ -1419,11 +1424,8 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { target_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "normalize"))); - llvm_ir::EmitTuple( - llvm_ir::IrArray(target_address, batch_norm_training->shape()), - {normalized, mean, var}, &ir_builder_); - emitted_value_[batch_norm_training] = target_address; - + llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training), + {normalized, mean, var}, &ir_builder_); return Status::OK(); } @@ -1451,13 +1453,19 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = ir_builder_.CreateLoad(param_address_offset); + param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); + if (hlo_module_config_.debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // We never reassign parameters, so this load is invariant. + param_address_untyped->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); + } + llvm::Value* param_address_typed = ir_builder_.CreateBitCast( param_address_untyped, IrShapeType(param_shape)->getPointerTo()); emitted_value_[parameter] = param_address_typed; - // Parameters of different types may not alias one another. - llvm_ir::SetTbaaForInstruction(param_address_untyped, param_shape, - /*is_pointer_to=*/true); if (!ShapeUtil::IsOpaque(param_shape)) { AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); @@ -1521,11 +1529,11 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( : ir_builder->CreateFMul(lhs, rhs); }; - case HloOpcode::kLogicalAnd: + case HloOpcode::kAnd: return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); }; - case HloOpcode::kLogicalOr: + case HloOpcode::kOr: return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; @@ -1661,7 +1669,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); @@ -1774,6 +1782,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( } CHECK(!ShapeUtil::IsTuple(reduce->shape())); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); // We know we're not reducing over the most minor dimension, which means we // can lower the reduction loop as: @@ -1836,10 +1845,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - llvm_ir::IrArray target_array(target_address, reduce->shape()); - AddAliasingInformationToIrArray(*reduce, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1871,10 +1877,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - llvm_ir::IrArray target_array(target_address, reduce->shape()); - AddAliasingInformationToIrArray(*reduce, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1885,10 +1888,6 @@ StatusOr IrEmitter::EmitVectorizedReduce( ir_builder_.SetInsertPoint(outermost_loop_exit_block); } - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - - emitted_value_[reduce] = target_address; return true; } @@ -1945,7 +1944,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, // filled in. We fill in the rest of the dimensions with induction // Value*s taken from 'index' which iterates over the target array. // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); @@ -1988,9 +1987,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { return DefaultAction(slice); } - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(slice)); - emitted_value_[slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); if (ShapeUtil::HasZeroElements(slice->shape())) { return Status::OK(); @@ -2062,8 +2059,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { outer_dims.push_back(memcpy_dim); } - llvm_ir::IrArray target_array(target_address, slice->shape()); - AddAliasingInformationToIrArray(*slice, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(slice); const int64 num_outer_loops = outer_dims.size(); llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_); @@ -2081,7 +2077,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); } - llvm_ir::IrArray source_array = GetIrArrayForOp(operand); + llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_); @@ -2116,126 +2112,26 @@ Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice, HloInstruction* operand, HloInstruction* /*start_indices*/) { if (ShapeUtil::IsScalar(dynamic_slice->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dynamic_slice)); - target_address->setName(AsStringRef(IrName(dynamic_slice))); - emitted_value_[dynamic_slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice)); return EmitMemcpy(*operand, *dynamic_slice); } return DefaultAction(dynamic_slice); } -namespace { - -// Returns the first non-GetTupleElement ancestor instruction of 'hlo'. -// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the -// (possibly nested) tuple indices used on the path from ancestor to 'hlo'. -const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo, - ShapeIndex* index) { - if (hlo->opcode() == HloOpcode::kGetTupleElement) { - const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index); - index->push_back(hlo->tuple_index()); - return operand; - } - return hlo; -} - -// Checks if we can emit code for DynamicUpdateSlice to update data in-place. -// Returns true if operand 0 of DynamicUpdateSlice and its output buffer -// share the same buffer allocation. -// Returns false otherwise. -// TODO(b/64142684) Share code with GPU implementation. -bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, - HloInstruction* dynamic_update_slice) { - CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); - - // Walk DynamicUpdateSlice operand(0) to parameter and get its - // associated operand. See if it shares an allocation with this operand. - ShapeIndex index; - auto* operand = - LatestNonGteAncestorAndIndex(dynamic_update_slice->operand(0), &index); - if (operand->opcode() != HloOpcode::kParameter) { - return false; - } - - BufferAllocation::Slice operand_slice = - assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie(); - - BufferAllocation::Slice dynamic_update_slice_slice = - assignment.GetUniqueTopLevelSlice(dynamic_update_slice) - .ConsumeValueOrDie(); - - return operand_slice == dynamic_update_slice_slice; -} - -} // namespace - Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dynamic_update_slice)); - target_address->setName(AsStringRef(IrName(dynamic_update_slice))); - emitted_value_[dynamic_update_slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); return EmitMemcpy(*update, *dynamic_update_slice); - } else if (CanUpdateDynamicSliceInPlace(assignment_, dynamic_update_slice)) { - VLOG(2) << "Emitting HandleDynamicUpdateSlice in-place."; - // DynamicUpdateSlice's operand(0) and 'fusion' output share the same - // BufferAllocation::Slice, so it is safe to emit code to update the slice - // 'in-place'. This avoids copying data outside of the slice update region. - // TODO(b/64142684) Implement in-place update for fused DynamicUpdateSlice. - - // Emit IR to read dynamic start indices from 'start_indices'. - const int64 rank = ShapeUtil::Rank(operand->shape()); - llvm_ir::IrArray::Index start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)}); - llvm_ir::IrArray start_indices_array(GetIrArrayForOp(start_indices)); - start_index[i] = - start_indices_array.EmitReadArrayElement(dim_index, &ir_builder_); - } - - // Create loop body emitter which emits code to do the following: - // *) Map requested 'index' and slice 'start_index' to input/output shape - // as 'output_index'. - // *) Reads value from 'update'. - // *) Writes value to input/output array at 'output_index'. - auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& index) -> Status { - // Calculate 'output_index' at which to write value from update. - llvm_ir::IrArray::Index output_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // output_index = (start_index + index) % dim_size - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), operand->shape().dimensions(i)); - llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast( - start_index[i], index[i]->getType()); - output_index[i] = ir_builder_.CreateURem( - ir_builder_.CreateAdd(start_index0, index[i]), dim_size); - } - - // Read value from 'update'. - llvm_ir::IrArray update_array(GetIrArrayForOp(update)); - llvm::Value* update_data = - update_array.EmitReadArrayElement(index, &ir_builder_); - - // Write value to output array. - GetIrArrayForOp(operand).EmitWriteArrayElement(output_index, update_data, - &ir_builder_); - return Status::OK(); - }; - - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(loop_body_emitter, update->shape(), &ir_builder_) - .EmitLoop(IrName(dynamic_update_slice, "in_place"))); - - TF_ASSIGN_OR_RETURN(llvm::Value * dynamic_update_slice_address, - EmitTargetAddressForOp(dynamic_update_slice)); - emitted_value_[dynamic_update_slice] = dynamic_update_slice_address; - return Status::OK(); + } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice, + assignment_)) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); + auto operands = GetIrArraysForOperandsOf(dynamic_update_slice); + return llvm_ir::EmitDynamicUpdateSliceInPlace( + operands, GetIrArrayFor(dynamic_update_slice), + IrName(dynamic_update_slice, "in_place"), &ir_builder_); } return DefaultAction(dynamic_update_slice); } @@ -2277,7 +2173,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); // Load an element from the operand. - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); @@ -2297,7 +2193,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { } // Store the operand element to the computed output location. - llvm_ir::IrArray output_array(GetIrArrayForOp(pad)); + llvm_ir::IrArray output_array(GetIrArrayFor(pad)); output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); @@ -2313,11 +2209,11 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { } Status IrEmitter::HandleFusion(HloInstruction* fusion) { + auto* root = fusion->fused_expression_root(); if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - const HloInstruction* dot = fusion->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(root->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && rhs_parameter->opcode() == HloOpcode::kParameter); const HloInstruction* lhs = @@ -2326,18 +2222,15 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { fusion->operand(rhs_parameter->parameter_number()); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*dot, /*operands=*/{lhs, rhs}, + /*instruction=*/*root, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32})); - llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); Shape target_shape = fusion->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*fusion, &target_array); - + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array = GetIrArrayFor(fusion); VLOG(2) << "HandleFusion kTransposeDot: "; VLOG(2) << " lhs operand: " << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); @@ -2348,19 +2241,25 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { // Dot operation is complicated so we delegate to a helper class. TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, dot->operand(0)->IsRank2Transpose(), - dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array, - GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_)); - - emitted_value_[fusion] = target_address; + *root, root->operand(0)->IsRank2Transpose(), + root->operand(1)->IsRank2Transpose(), target_array, lhs_array, + rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, + hlo_module_config_)); return Status::OK(); + } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, + assignment_)) { + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + + // Delegate to common implementation of fused in-place dynamic-update-slice. + auto operands = GetIrArraysForOperandsOf(fusion); + return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( + fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, + &ir_builder_); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { - std::vector parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArrayForOp(operand)); - } CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + auto operands = GetIrArraysForOperandsOf(fusion); + FusedIrEmitter fused_emitter(operands, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); @@ -2378,14 +2277,20 @@ Status IrEmitter::HandleCall(HloInstruction* call) { parameter_addresses.push_back(GetEmittedValueFor(operand)); } - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(call)); - output_address->setName(AsStringRef(IrName(call))); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - output_address, computation->name()); + if (!computation->root_instruction()->outer_dimension_partitions().empty() && + !parallel_cpu_backend_) { + // ParallelTaskAssignment assigned partitions, emit call to + // ParallelForkJoin. + TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, + emitted_value_[call], computation, + call_ir_function)); + } else { + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, + emitted_value_[call], computation->name()); + } - emitted_value_[call] = output_address; return Status::OK(); } @@ -2414,17 +2319,13 @@ Status IrEmitter::HandleCustomCall( /*Params=*/{i8_ptr_type, operands_alloca->getType()}, /*isVarArg=*/false))); - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(custom_call)); - output_address->setName(AsStringRef(IrName(custom_call))); - - auto* output_address_arg = - ir_builder_.CreatePointerCast(output_address, i8_ptr_type); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + auto* output_address_arg = ir_builder_.CreatePointerCast( + GetEmittedValueFor(custom_call), i8_ptr_type); ir_builder_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); - emitted_value_[custom_call] = output_address; return Status::OK(); } @@ -2568,10 +2469,8 @@ StatusOr IrEmitter::EmitFastConcatenate( llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::Type* i8_type = ir_builder_.getInt8Ty(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(concatenate)); - - llvm_ir::IrArray target_array(target_address, output_shape); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); + llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_); llvm_ir::IrArray::Index outer_dims_index = @@ -2588,8 +2487,6 @@ StatusOr IrEmitter::EmitFastConcatenate( unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - AddAliasingInformationToIrArray(*concatenate, &target_array); - // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = ir_builder_.CreateBitCast( @@ -2608,7 +2505,7 @@ StatusOr IrEmitter::EmitFastConcatenate( // equal to the product of inner dimensions. for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); - llvm_ir::IrArray source_array = GetIrArrayForOp(operand); + llvm_ir::IrArray source_array = GetIrArrayFor(operand); llvm::Value* copy_source_address = ir_builder_.CreateBitCast( source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, "src_addr"), @@ -2632,8 +2529,6 @@ StatusOr IrEmitter::EmitFastConcatenate( SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); } - emitted_value_[concatenate] = target_address; - return true; } @@ -2705,7 +2600,7 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // For the parallel cpu backend, we record the total for each embedded // computation callee with its caller kCall HLO. HloInstruction* hlo_to_lookup = nullptr; - if (IsParallelContext()) { + if (parallel_cpu_backend_ && is_top_level_computation_) { auto* computation = root->parent(); auto* entry_computation = computation->parent()->entry_computation(); if (computation != entry_computation) { @@ -2833,7 +2728,7 @@ Status IrEmitter::Postprocess(HloInstruction* hlo) { return Status::OK(); } -llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) { +llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) { llvm::Value* value_for_op = GetEmittedValueFor(hlo); llvm_ir::IrArray array(value_for_op, hlo->shape()); @@ -2841,6 +2736,16 @@ llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) { return array; } +std::vector IrEmitter::GetIrArraysForOperandsOf( + const HloInstruction* hlo) { + std::vector arrays; + std::transform( + hlo->operands().begin(), hlo->operands().end(), + std::back_inserter(arrays), + [&](const HloInstruction* operand) { return GetIrArrayFor(operand); }); + return arrays; +} + llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { auto it = emitted_value_.find(hlo); if (it == emitted_value_.end()) { @@ -2853,12 +2758,27 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, &ir_builder_); } +std::vector IrEmitter::GetComputeFunctionParams() { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } + if (hlo_to_profile_idx_) { + compute_function_params.push_back(i64_ptr_type); + } + return compute_function_params; +} + llvm::Argument* IrEmitter::GetResultArgument() { return GetArg(compute_function_, 0); } llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = IsParallelContext() ? 5 : 4; + const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; } @@ -2909,14 +2829,12 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( ir_builder_.CreateLoad(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. + // Loading the address of a buffer is invariant of the point at which the + // load is executed in the program because we never reassign buffers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); } - llvm_ir::SetTbaaForInstruction(tempbuf_address_base, target_shape, - /*is_pointer_to=*/true); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); @@ -2943,18 +2861,11 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector IrEmitter::GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { llvm::Value* parameter_addresses_buffer = @@ -2983,7 +2894,26 @@ void IrEmitter::EmitArrayFunctionCallInto( if (auto* profile_counters = GetProfileCountersArgument()) { arguments.push_back(profile_counters); } - ir_builder_.CreateCall(function, arguments); + return arguments; +} + +// Emits a core function call based on the following pseudo-code. +// +// char** parameter_addresses_buffer = +// allocate buffer with a pointer for each parameter to the function +// for each parameter index, i.e. for i = 0, ..., #parameters: +// parameter_addresses_buffer[i] = parameter_addresses[i] +// call function(return_value_buffer, +// parameter_addresses_buffer, +// temps) +// return return_value_buffer -- address of the return value. +void IrEmitter::EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name) { + ir_builder_.CreateCall( + function, GetArrayFunctionCallArguments(parameter_addresses, + return_value_buffer, name)); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -3003,10 +2933,114 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -StatusOr IrEmitter::EmitTargetAddressForOp( - const HloInstruction* op, const ShapeIndex& shape_index) { - const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index); - if (op == op->parent()->root_instruction() && shape_index.empty()) { +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status IrEmitter::EmitParallelForkJoin( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function) { + HloInstruction* root = computation->root_instruction(); + + // Build ParallelForkJoin function type. + std::vector compute_function_params = GetComputeFunctionParams(); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module_->getContext())); + // Number of partitioned most-major dimensions in 'root.shape'. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module_->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module_->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module_->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + const string name = computation->name(); + std::vector arguments = + GetArrayFunctionCallArguments(parameter_addresses, output_address, name); + + // Create ShapePartitionIterator to generate all partitions of 'root.shape'. + ShapePartitionIterator partition_iterator(root->shape(), + root->outer_dimension_partitions()); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + arguments.push_back(ir_builder_.getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'root.shape'. + const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder_.getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder_.getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + arguments.push_back(ir_builder_.CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module_->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + arguments.push_back( + ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder_.CreateCall(fork_join_func, arguments); + + return Status::OK(); +} + +Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { + llvm::Value* addr; + const Shape& target_shape = op->shape(); + if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = GetResultArgument(); @@ -3016,15 +3050,18 @@ StatusOr IrEmitter::EmitTargetAddressForOp( attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); retval->addAttrs(attr_builder); } - return ir_builder_.CreateBitCast(retval, + addr = ir_builder_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); - } - - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - return EmitTempBufferPointer(slice, target_shape); + } else { + // For other nodes, we need the temporary buffer allocated for this node to + // write the result into. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + addr = EmitTempBufferPointer(slice, target_shape); + } + addr->setName(AsStringRef(IrName(op))); + emitted_value_[op] = addr; + return Status::OK(); } Status IrEmitter::EmitTargetElementLoop( @@ -3038,12 +3075,9 @@ Status IrEmitter::EmitTargetElementLoop( const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); - // target_address will hold the address of the target buffer we will write the - // result of the computation into. const Shape& target_shape = target_op->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(target_op)); - VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); + llvm_ir::IrArray target_array = GetIrArrayFor(target_op); if (target_op->IsMultiOutputFusion()) { // For multiple outputs fusion, we need to emit each operand and the root. @@ -3066,13 +3100,9 @@ Status IrEmitter::EmitTargetElementLoop( for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape), - tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_); } else { - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*target_op, &target_array); - if (ShouldEmitParallelLoopFor(*target_op)) { TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( target_shape, element_generator, IrName(target_op), &target_array)); @@ -3082,8 +3112,6 @@ Status IrEmitter::EmitTargetElementLoop( .EmitLoop(IrName(target_op))); } } - - emitted_value_[target_op] = target_address; return Status::OK(); } @@ -3175,7 +3203,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_); }; } CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 8042e03e69561aeacccc5498eaf52f32bbd78b62..58c185af1ec6e00c854b39f77281da7760855fe0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -104,11 +105,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. + // external_constant_pool: if non-null, points to an ExternalConstantPool + // instance into which the Ir emitter can spill + // constants. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx, - llvm::TargetMachine* target_machine); + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -146,7 +151,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // // Default action which emits code for most operations. Operations which are // special in some way are handled explicitly in HandleFoo methods. - Status DefaultAction(HloInstruction* hlo_instruction) override; + Status DefaultAction(HloInstruction* hlo) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant, @@ -175,7 +180,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleReduceWindow(HloInstruction* reduce_window, HloInstruction* operand, const Window& window, HloComputation* function) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; Status HandleSend(HloInstruction* send) override; Status HandleSlice(HloInstruction* slice, HloInstruction* /*operand*/) override; @@ -208,7 +213,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; - Status Postprocess(HloInstruction* visited) override; + Status Postprocess(HloInstruction* hlo) override; private: // Private helper to initialize an IR function for the computation. @@ -220,8 +225,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Gets the IR Value emitted previously for the given hlo. // - // Prefer calling GetIrArrayForOp if the value you're reading is a buffer, - // because GetIrArrayForOp annotates buffer's loads/stores with noalias + // Prefer calling GetIrArrayFor if the value you're reading is a buffer, + // because GetIrArrayFor annotates buffer's loads/stores with noalias // metadata. // // Make sure to call this only when you're certain a value *was* emitted - if @@ -229,7 +234,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetEmittedValueFor(const HloInstruction* hlo); // Gets an IrArray representing the given hlo. - llvm_ir::IrArray GetIrArrayForOp(const HloInstruction* hlo); + llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo); + + // Gets a list of IrArrays, one for each of hlo's operands. + std::vector GetIrArraysForOperandsOf( + const HloInstruction* hlo); // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, @@ -240,6 +249,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); + // Returns an array of compute function parameter types. + std::vector GetComputeFunctionParams(); + // Get the llvm::Value* that represents the "retval" argument of the // computation function being emitted by this emitter. llvm::Argument* GetResultArgument(); @@ -304,7 +316,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { void EmitArrayFunctionCallInto( llvm::Function* function, tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value, tensorflow::StringPiece name); + llvm::Value* return_value_buffer, tensorflow::StringPiece name); // Array function call emitter. Returns a Value for the function's return // value buffer address. The return value buffer is alloca'ed by this @@ -314,6 +326,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); + // Returns an array of compute function call arguments. + std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name); + + // Emits a call to a runtime fork/join function which dispatches parallel + // calls to 'parallel_function' (and joins threads before returning). + Status EmitParallelForkJoin( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function); + // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -353,11 +377,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitMemcpy(const HloInstruction& source, const HloInstruction& destination); - // Emit IR to compute the target address of the buffer for the given op. - // The returned Value is a pointer to a IR type that represents the op's - // element type. - StatusOr EmitTargetAddressForOp( - const HloInstruction* op, const ShapeIndex& shape_index = {}); + // Emits IR to compute the target address of the buffer for the given op. + // After calling this function, you can get a pointer to this buffer by + // calling GetIrArrayForOp or GetEmittedValueFor. + Status EmitTargetAddressForOp(const HloInstruction* op); // Structurizes "array_elements" into an MD array that represents "shape". // This is a recursive function, and "dimension_index" indicates the index of @@ -447,10 +470,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); - // Name of the computation entry function. This function serves as the - // top-level "main" of the computation and will be invoked by the JIT. - string entry_function_name_; - // Assignment of the temporary buffers needed by the computation and their // shape information. const BufferAssignment& assignment_; @@ -592,12 +611,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - // Returns true if the current function being emitted is called in a - // parallel context (returns false otherwise). - bool IsParallelContext() { - return parallel_cpu_backend_ && is_top_level_computation_; - } - const HloModuleConfig& hlo_module_config_; const bool parallel_cpu_backend_; @@ -606,6 +619,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { TargetMachineFeatures target_machine_features_; + int64 external_global_constant_counter_ = 0; + ExternalConstantPool* external_constant_pool_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc new file mode 100644 index 0000000000000000000000000000000000000000..5afb2e67fff639e9cabb3740c5240e1ca90b5644 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -0,0 +1,249 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { +namespace cpu { + +class SimpleCostModel : public ParallelCostModel { + public: + SimpleCostModel(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_(shape_size) {} + ~SimpleCostModel() override {} + + int64 GetParallelTaskCount(HloInstruction* instruction) override { + // Simple cost model based on hlo size and typical L2 cache size. + const int64 instruction_cost = shape_size_(instruction->shape()); + const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + // Return target parallel task count in [1, max_parallelism_]. + return std::min(max_parallelism_, + std::max(1LL, instruction_cost / min_cost_per_thread)); + } + + private: + const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; +}; + +class DefaultCostModel : public ParallelCostModel { + public: + DefaultCostModel(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + std::unique_ptr cost_analysis) + : max_parallelism_(max_parallelism), + shape_size_(shape_size), + cost_analysis_(std::move(cost_analysis)) {} + ~DefaultCostModel() override {} + + int64 GetParallelTaskCount(HloInstruction* instruction) override { + // Parameters for parallel task count computation. + int64 instruction_cost; + int64 min_cost_per_thread; + int64 max_parallelism; + // Calculate flops-to-bytes-ratio for 'instruction'. + const int64 bytes_accessed = + std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + const float flops_to_bytes_ratio = + cost_analysis_->flop_count(*instruction) / + static_cast(bytes_accessed); + // Check for I/O bound instructions. + if (flops_to_bytes_ratio <= 1.0) { + // Limit max parallelism for I/O bound instructions by assuming a + // sub-linear scaling function (fit based on empirical benchmark results). + // TODO(29630486) Develop system bandwidth model. + max_parallelism = + std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); + // Use shape size instruction cost and L2 cache size min per-thread cost. + instruction_cost = shape_size_(instruction->shape()); + min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + } else { + // Use max parallelism for compute bound instructions. + max_parallelism = max_parallelism_; + // Calculate the instruction cost in cycles. + // TODO(29630486) Improve on this linear cost model. + // Consider making 'min_cost_per_thread' be a function of the target + // bandwidth limit for instructions with low arithmetic complexity. + instruction_cost = + 1 * cost_analysis_->flop_count(*instruction) + + 2 * cost_analysis_->transcendental_count(*instruction) + + 10 * cost_analysis_->bytes_accessed(*instruction); + // Minimum per-thread cost is 100us of work on a 2GHz core. + min_cost_per_thread = 100000; + } + // Return target parallel task count in [1, max_parallelism_]. + return std::min(max_parallelism, + std::max(1LL, instruction_cost / min_cost_per_thread)); + } + + private: + const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; + const std::unique_ptr cost_analysis_; +}; + + +ParallelTaskAssignment::ParallelTaskAssignment( + const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module) { + VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; + // Run cost analysis on 'module'. + auto cost_analysis = MakeUnique(shape_size); + HloComputation* computation = module->entry_computation(); + Status status = computation->root_instruction()->Accept(cost_analysis.get()); + if (status.ok()) { + // Set default cost model based on 'cost_analysis'. + cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size, + std::move(cost_analysis))); + } else { + // Fall back to a simple cost model based on hlo size and L2 cache size. + // Note that HloCostAnalysis can returns an error status (likely because + // HLOs like CustomCall are not yet implemented in the HloCostAnalysis). + cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size)); + } +} + +int64 ParallelTaskAssignment::GetTargetParallelTaskCount( + HloInstruction* instruction) { + // Currently, we do not assign parallel tasks to instructions with at least + // one of the following properties: + // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Tuple-shaped. + // TODO(b/27458679) Parallelize instructions which are skipped here. + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kCustomCall || + instruction->opcode() == HloOpcode::kSelectAndScatter || + (instruction->opcode() == HloOpcode::kConvolution && + PotentiallyImplementedAsEigenConvolution(*instruction)) || + PotentiallyImplementedAsEigenDot(*instruction) || + (instruction->opcode() == HloOpcode::kFusion && + instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || + ShapeUtil::IsTuple(instruction->shape())) { + return 1; + } + // Consult 'cost_model_' to compute target parallel task count. + return cost_model_->GetParallelTaskCount(instruction); +} + +StatusOr ParallelTaskAssigner::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); + XLA_VLOG_LINES(3, module->ToString()); + + // Compute target parallel task counts for all instructions in 'module'. + HloToParallelTasks hlo_to_parallel_tasks; + ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); + + // Assign parallel tasks to target specific instructions in 'module'. + // TODO(b/27458679) Support inter-op parallelism. + bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks); + + XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT"); + XLA_VLOG_LINES(3, module->ToString()); + return changed; +} + +bool ParallelTaskAssigner::AssignParallelTasks( + HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) { + return AssignParallelTasksHelper(module, module->entry_computation(), + hlo_to_parallel_tasks); +} + +bool ParallelTaskAssigner::AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks) { + bool changed = false; + // Snapshot set of instructions because outlining modifies the set below. + std::vector instructions(computation->instructions().begin(), + computation->instructions().end()); + for (auto* instruction : instructions) { + // Assign parallel tasks to sub-computations for While and Call HLOs. + // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement, + // and support other callable computations like reduce. + if (instruction->opcode() == HloOpcode::kWhile) { + changed |= AssignParallelTasksHelper(module, instruction->while_body(), + hlo_to_parallel_tasks); + continue; + } else if (instruction->opcode() == HloOpcode::kCall) { + changed |= AssignParallelTasksHelper(module, instruction->to_apply(), + hlo_to_parallel_tasks); + continue; + } + // Skip if no parallel tasks were computed in first pass. + auto it = hlo_to_parallel_tasks.find(instruction); + if (it == hlo_to_parallel_tasks.end()) { + continue; + } + // Get target parallel task count computed for 'instruction'. + const int64 target_parallel_task_count = (*it).second; + // Assign feasible dimension partitions (based on actual dimension sizes). + auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) + .Run(target_parallel_task_count); + const int64 total_partition_count = + ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); + if (total_partition_count <= 1) { + // Feasible partition calculation resulting in no partitioning, so skip. + continue; + } + + // Outline 'instruction' in 'computation' for parallel task assignment. + auto* call = module->OutlineExpressionFromComputation( + {instruction}, + tensorflow::strings::StrCat("parallel_", instruction->name()), + computation); + + // Set assigned dimension partitioning to 'instruction'. + auto* new_root = call->to_apply()->root_instruction(); + new_root->set_outer_dimension_partitions(dim_partition_counts); + + VLOG(2) << "Assigned parallel task count: " << total_partition_count + << " to instruction: " << new_root->name() + << " parent: " << new_root->parent()->name(); + changed = true; + } + return changed; +} + +void ParallelTaskAssigner::ComputeTargetParallelTasks( + HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { + // Compute parallel task counts for all instructions in 'module'. + for (auto* computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (auto* instruction : computation->instructions()) { + // Query ParallelTaskAssignment for target parallel task count. + const int64 target_parallel_task_count = + parallel_task_assignment_.GetTargetParallelTaskCount(instruction); + if (target_parallel_task_count > 1) { + hlo_to_parallel_tasks->insert( + {instruction, target_parallel_task_count}); + } + } + } +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h new file mode 100644 index 0000000000000000000000000000000000000000..e036da5784f6151eb3b01107ec7f3ab820071a60 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ + +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace cpu { + +// Simple interface for different parallel cost model implementations. +class ParallelCostModel { + public: + virtual ~ParallelCostModel() = default; + virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0; +}; + +// ParallelTaskAssignment computes parallel task counts for HLOs in 'module'. +class ParallelTaskAssignment { + public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. + // 'module': the containing HloModule. + ParallelTaskAssignment( + const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module); + ~ParallelTaskAssignment() {} + + // Computes and returns the target parallel task count for 'instruction'. + int64 GetTargetParallelTaskCount(HloInstruction* instruction); + + private: + std::unique_ptr cost_model_; +}; + +// ParallelTaskAssigner computes target parallel task counts for all HLOs +// in the module, then assigns parallel task counts to HLOs in the entry +// computation, or to HLOs in embedded computations invoked by (potentially +// nested) kWhile or kCall instructions. +// Each HLO which is assigned parallel task counts is outlined into its +// own embedded computation, which is compiled as a parallel compute function, +// and which is invoked from a kCall instruction that is lowered in codegen to +// a runtime parallel fork/join call. +class ParallelTaskAssigner : public HloPassInterface { + public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. + // 'module': the containing HloModule. + ParallelTaskAssigner(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module) + : parallel_task_assignment_(max_parallelism, shape_size, module) {} + ~ParallelTaskAssigner() override {} + + tensorflow::StringPiece name() const override { + return "cpu-parallel-task-assigner"; + } + + // Run parallel task assigner on 'module'. + // Returns true if the computation was changed, false otherwise. + StatusOr Run(HloModule* module) override; + + private: + using HloToParallelTasks = std::unordered_map; + + // Assigns target parallel tasks from 'hlo_to_parallel_tasks' to HLOs in + // 'module'. + // Returns true if the computation was changed, false otherwise. + bool AssignParallelTasks(HloModule* module, + const HloToParallelTasks& hlo_to_parallel_tasks); + bool AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks); + + // Computes target parallel task counts (returned in 'parallel_task_counts') + // for parallelizable instructions in 'module'. + void ComputeTargetParallelTasks(HloModule* module, + HloToParallelTasks* hlo_to_parallel_tasks); + + ParallelTaskAssignment parallel_task_assignment_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc new file mode 100644 index 0000000000000000000000000000000000000000..d03da46575b331de113cc5f33c2b4267504e8308 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::uint64; + +using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, + int64*, uint64*); + +// Dispatches 'num_partitions - 1' calls to 'function_ptr' in parallel. +// Calls 'function_ptr' for first partition inline. +// Uses blocking counter to synchonize threads after parallel calls complete. +// +// The 'partitions' array has a total number of elements equal to +// 'num_partitions * num_partitioned_dims * 2' (the '2' is necessary to specify +// dimension start and limit indices). +// +// The 'partitions' array layout stores array elements in memory with dimension +// start limit as the most-minor dimension, followed by dimension, then +// partition. +// +// EX: Layout of 'partitions' array with 'num_partitions = 2', and +// 'num_partitioned_dims = 3' +// +// [partition0_dim0_start] +// [partition0_dim0_limit] +// [partition0_dim1_start] +// [partition0_dim1_limit] +// [partition0_dim2_start] +// [partition0_dim2_limit] +// [partition1_dim0_start] +// [partition1_dim0_limit] +// [partition1_dim1_start] +// [partition1_dim1_limit] +// [partition1_dim2_start] +// [partition1_dim2_limit] +// +void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, uint64* prof_counters, int32 num_partitions, + int64* partitions, int32 num_partitioned_dims, void* function_ptr) { + VLOG(2) << "ParallelForkJoin ENTRY" + << " num_partitions: " << num_partitions + << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_GT(num_partitions, 1); + CHECK_GT(num_partitioned_dims, 0); + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + ComputeFunctionType function = + reinterpret_cast(function_ptr); + // Compute partition stride in 'partitions' array. + const int64 stride = 2 * num_partitioned_dims; + + // Dispatch 'num_partitions - 1' compute functions to run in parallel. + tensorflow::BlockingCounter bc(num_partitions - 1); + for (int32 i = 1; i < num_partitions; ++i) { + const int64 offset = i * stride; + run_options->intra_op_thread_pool()->enqueueNoNotification( + [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + partitions, offset, &bc]() { + function(result_ptr, run_options_ptr, params, temps, + &partitions[offset], prof_counters); + bc.DecrementCount(); + VLOG(3) << "ParallelForkJoin partition " << i << " done."; + }); + } + + // Call first compute function inline. + function(result_ptr, run_options_ptr, params, temps, &partitions[0], + prof_counters); + VLOG(3) << "ParallelForkJoin partition 0 done."; + bc.Wait(); + VLOG(2) << "ParallelForkJoin EXIT"; +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h new file mode 100644 index 0000000000000000000000000000000000000000..fcf1cc62078d3847435a2e75e3ca9d109cf8b200 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// Dispatches 'num_partitions' parallel calls to 'function_ptr' and joins +// threads before returning. See comments in runtime_fork_join.cc for details. +extern void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, tensorflow::uint64* prof_counters, + tensorflow::int32 num_partitions, tensorflow::int64* partitions, + tensorflow::int32 num_partitioned_dims, void* function_ptr); + +} // extern "C" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index c3c11df090e88c3c24104b66d28b3b16f03baa80..cfffb3fbc30349e41e9053bf7982507cd6ed1052 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -104,6 +105,7 @@ class JITSymbolTable { ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32); ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32); ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64); + ADD_JIT_SYMBOL_TO_TABLE(ParallelForkJoin); #undef ADD_JIT_SYMBOL_TO_TABLE } @@ -117,8 +119,20 @@ const JITSymbolTable& GetJITSymbolTable() { } // A simple SymbolResolver that delegates to the host dynamic linker. -struct SimpleResolver : public llvm::JITSymbolResolver { +class SimpleResolver : public llvm::JITSymbolResolver { + public: + explicit SimpleResolver(ExternalConstantPool* external_constant_pool) + : external_constant_pool_(external_constant_pool) {} + llvm::JITSymbol findSymbol(const std::string& name) override { + string name_as_string(name); + if (const uint8* from_constant_pool = + external_constant_pool_->Find(string(name))) { + return llvm::JITEvaluatedSymbol( + reinterpret_cast(from_constant_pool), + llvm::JITSymbolFlags::None); + } + std::string canonical_name = CanonicalizeSymbol(name); const JITSymbolTable& jit_symbol_table = GetJITSymbolTable(); @@ -136,6 +150,9 @@ struct SimpleResolver : public llvm::JITSymbolResolver { llvm::JITSymbol findSymbolInLogicalDylib(const std::string& name) override { return nullptr; } + + private: + ExternalConstantPool* external_constant_pool_; }; llvm::SmallVector DetectMachineAttributes() { @@ -205,7 +222,7 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto handle = cantFail(compile_layer_.addModule( - std::move(module), MakeUnique())); + std::move(module), MakeUnique(external_constant_pool()))); module_handles_.push_back(handle); return handle; } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index e476c0e3812cc0fb2a2d633832374b3165ca072a..ded01e9e4d7442296f7406dd035e6ab385458238 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -90,6 +91,10 @@ class SimpleOrcJIT { llvm::TargetMachine* target_machine() const { return target_machine_.get(); } + ExternalConstantPool* external_constant_pool() { + return &external_constant_pool_; + } + private: std::vector module_handles_; std::unique_ptr target_machine_; @@ -97,6 +102,7 @@ class SimpleOrcJIT { const llvm::DataLayout data_layout_; ObjLayerT object_layer_; CompileLayerT compile_layer_; + ExternalConstantPool external_constant_pool_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 2c16a1b9033f45742f80b91eb1695315bd13ed80..5b1dbf439c7d3b02625e9d846a068b2262ceeeed 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -156,18 +156,32 @@ class DfsHloVisitor { HloInstruction* operand) { return HandleElementwiseUnary(is_finite); } - virtual Status HandleLogicalAnd(HloInstruction* logical_and, - HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and); + virtual Status HandleAnd(HloInstruction* and_, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(and_); + } + virtual Status HandleNot(HloInstruction* not_, HloInstruction* operand) { + return HandleElementwiseUnary(not_); } - virtual Status HandleLogicalNot(HloInstruction* logical_not, - HloInstruction* operand) { - return HandleElementwiseUnary(logical_not); + virtual Status HandleOr(HloInstruction* or_, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(or_); } - virtual Status HandleLogicalOr(HloInstruction* logical_or, + virtual Status HandleShiftLeft(HloInstruction* shift_left, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or); + return HandleElementwiseBinary(shift_left); + } + virtual Status HandleShiftRightArithmetic( + HloInstruction* shift_right_arithmetic, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_arithmetic); } + virtual Status HandleShiftRightLogical(HloInstruction* shift_right_logical, + HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_logical); + } + virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { return HandleElementwiseUnary(reduce_precision); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 7117ecb08b2c1f83f155ca3d25d2831b5e411bc6..44f709bedec7ef0e50b830e8901796985ee7224e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -126,14 +126,21 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } case HloOpcode::kNegate: return ir_builder_->CreateNeg(operand_value); - case HloOpcode::kLogicalNot: - // It is not sufficient to just call CreateNot() here because a PRED is - // represented as an i8 and the truth value is stored only in the bottom - // bit. - return ir_builder_->CreateZExt( - ir_builder_->CreateNot(ir_builder_->CreateTrunc( - operand_value, ir_builder_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + case HloOpcode::kNot: { + auto type = op->shape().element_type(); + if (type == PRED) { + // It is not sufficient to just call CreateNot() here because a PRED + // is represented as an i8 and the truth value is stored only in the + // bottom bit. + return ir_builder_->CreateZExt( + ir_builder_->CreateNot(ir_builder_->CreateTrunc( + operand_value, ir_builder_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + } else if (primitive_util::IsIntegralType(type)) { + return ir_builder_->CreateNot(operand_value); + } + return Unimplemented("unary op Not is not defined for type '%d'", type); + } default: return Unimplemented("unary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -557,10 +564,16 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, lhs_value, rhs_value), lhs_value, rhs_value); - case HloOpcode::kLogicalAnd: + case HloOpcode::kAnd: return ir_builder_->CreateAnd(lhs_value, rhs_value); - case HloOpcode::kLogicalOr: + case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kShiftLeft: + return ir_builder_->CreateShl(lhs_value, rhs_value); + case HloOpcode::kShiftRightArithmetic: + return ir_builder_->CreateAShr(lhs_value, rhs_value); + case HloOpcode::kShiftRightLogical: + return ir_builder_->CreateLShr(lhs_value, rhs_value); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -799,7 +812,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, @@ -821,8 +834,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { const HloInstruction* lhs = hlo->operand(0); @@ -879,17 +895,31 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; + llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. + CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + init_block->getTerminator() == nullptr); + + llvm::BasicBlock* exit_block; + if (ir_builder_->GetInsertPoint() == init_block->end()) { + exit_block = llvm_ir::CreateBasicBlock( + /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + } else { + exit_block = init_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); + init_block->getTerminator()->eraseFromParent(); + } + + llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); llvm::PHINode* output = ir_builder_->CreatePHI( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), ir_builder_), hlo->operands().size()); - llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); auto prior_insert_point = ir_builder_->GetInsertPoint(); - llvm::BasicBlock* exit_block = - init_block->splitBasicBlock(output, "concat_merge"); ir_builder_->SetInsertPoint(init_block); - init_block->getTerminator()->eraseFromParent(); for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 79fedb61c971862fc0e3a59e01e55825f09c587d..62b8fa6a2b77e21ae3aa257935f5a22e3e8a130b 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -69,15 +69,6 @@ Status Executable::DumpSessionModule() { *session_module_); } -// Removes illegal characters from filenames. -static void SanitizeFilename(string* name) { - for (char& c : *name) { - if (c == '/' || c == '\\' || c == '[' || c == ']') { - c = '_'; - } - } -} - /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, const SessionModule& session_module) { @@ -89,7 +80,7 @@ static void SanitizeFilename(string* name) { // "directory already exists" error. TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); } - SanitizeFilename(&filename); + filename = SanitizeFileName(std::move(filename)); string file_path = tensorflow::io::JoinPath(directory_path, filename); return tensorflow::WriteBinaryProto(env, file_path, session_module); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 82c32407d3d10dffc89c448ee402ea3a789a0c39..de84e06cebab72d272bd888f280f5e5b221b97d1 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -104,7 +104,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:core", ], @@ -147,6 +147,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 7cf5613ce571cadc5ad45ade17341376f3d6ae39..5aaf072f9d2c95e2fff70a1c5337432a12a1aa48 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -72,8 +72,10 @@ MatchBackwardFilter(HloInstruction* conv) { // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = conv->convolution_dimension_numbers(); - auto batch_dim = conv_dnums.batch_dimension(); - auto feature_dim = conv_dnums.feature_dimension(); + auto input_batch_dim = conv_dnums.input_batch_dimension(); + auto input_feature_dim = conv_dnums.input_feature_dimension(); + auto output_batch_dim = conv_dnums.output_batch_dimension(); + auto output_feature_dim = conv_dnums.output_feature_dimension(); auto spatial_dims = conv_dnums.spatial_dimensions(); for (const WindowDimension& window_dim : conv->window().dimensions()) { @@ -183,8 +185,10 @@ MatchBackwardFilter(HloInstruction* conv) { // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; - backward_conv_dnums.set_batch_dimension(feature_dim); - backward_conv_dnums.set_feature_dimension(batch_dim); + backward_conv_dnums.set_input_batch_dimension(input_feature_dim); + backward_conv_dnums.set_input_feature_dimension(input_batch_dim); + backward_conv_dnums.set_output_batch_dimension(output_feature_dim); + backward_conv_dnums.set_output_feature_dimension(output_batch_dim); for (int i = 0; i < spatial_dims.size(); ++i) { backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]); } @@ -198,9 +202,9 @@ MatchBackwardFilter(HloInstruction* conv) { // the dimension numbering of the weight gradients. This transposition maps // dimension i to PositionInContainer(transpose->dimensions(), i). backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), batch_dim)); + PositionInContainer(transpose->dimensions(), output_batch_dim)); backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), feature_dim)); + PositionInContainer(transpose->dimensions(), output_feature_dim)); for (int i = 0; i < spatial_dims.size(); ++i) { backward_conv_dnums.add_kernel_spatial_dimensions( PositionInContainer(transpose->dimensions(), spatial_dims[i])); @@ -275,7 +279,7 @@ MatchBackwardInput(HloInstruction* conv) { Window new_window = old_window; for (size_t i = 0; i < spatial_dims.size(); ++i) { // Restore backward convolution's padding config from the matched pattern. - // See the comment in tensorflow/core/kernels/conv_grad_ops.cc + // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc // for how we convert backward input convolution to a variant of forward // convolution. // diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 6699c8f3c4acd76ed58cccf314ca0ae1502d51d7..19b122ba0603b4ec08d73e05da4c2ae11a760553 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -45,8 +45,10 @@ class ConvolutionFoldingTest : public HloTestBase { // dimension in gradients as the input feature dimension in the filter. // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. - tf_default_dnums_for_backward_filter_.set_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); @@ -55,8 +57,10 @@ class ConvolutionFoldingTest : public HloTestBase { tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); - tf_default_dnums_for_backward_input_.set_batch_dimension(0); - tf_default_dnums_for_backward_input_.set_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.add_spatial_dimensions(1); tf_default_dnums_for_backward_input_.add_spatial_dimensions(2); tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); @@ -250,8 +254,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { conv_window.mutable_dimensions(i)->set_padding_high(3); } ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_batch_dimension(0); - conv_dnums.set_feature_dimension(1); + conv_dnums.set_input_batch_dimension(0); + conv_dnums.set_output_batch_dimension(0); + conv_dnums.set_input_feature_dimension(1); + conv_dnums.set_output_feature_dimension(1); conv_dnums.add_spatial_dimensions(2); conv_dnums.add_spatial_dimensions(3); conv_dnums.set_kernel_input_feature_dimension(0); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 89145a9038c23e02b1b25140ff3711dc44185d0c..536b96dcf620e908e25a775bc2efb57ba5f5edd6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -141,8 +141,8 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( BatchDescriptor input_descriptor(effective_num_dimensions); input_descriptor.set_layout(DataLayout::kBatchDepthYX) .set_feature_map_count( - input_shape_.dimensions(dim_nums_.feature_dimension())) - .set_count(input_shape_.dimensions(dim_nums_.batch_dimension())); + input_shape_.dimensions(dim_nums_.input_feature_dimension())) + .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { // Note that the dimensions are reversed. The same holds below. input_descriptor.set_spatial_dim( @@ -176,8 +176,8 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( BatchDescriptor output_descriptor(effective_num_dimensions); output_descriptor.set_layout(DataLayout::kBatchDepthYX) .set_feature_map_count( - output_shape_.dimensions(dim_nums_.feature_dimension())) - .set_count(output_shape_.dimensions(dim_nums_.batch_dimension())); + output_shape_.dimensions(dim_nums_.output_feature_dimension())) + .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { output_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), @@ -256,9 +256,9 @@ tensorflow::Status ConvolutionThunk::Convolve( algorithm_config.algorithm_no_scratch().algo_id()); } -std::vector ConvolutionThunk::GetAlgorithms( +std::vector ConvolutionThunk::GetAlgorithms( se::StreamExecutor* stream_exec) const { - std::vector algorithms; + std::vector algorithms; // TODO(yangzihao): Currently disable the use of winograd nonfused in XLA // by default. Should send in conv parameters and enable it when // ShouldIncludeWinogradNonfusedAlgo() returns true. @@ -297,32 +297,27 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( se::dnn::ProfileResult best_result; se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = - GetAlgorithms(stream->parent()); - for (bool use_tensor_ops : {false, true}) { - for (auto algo_index : algorithms) { - AlgorithmDesc algorithm(algo_index, use_tensor_ops); - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - se::dnn::ProfileResult profile_result; - bool launch_ok = - Convolve(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, - se::dnn::AlgorithmConfig(algorithm, algorithm), stream, - &scratch_allocator, &profile_result) - .ok(); - if (launch_ok && profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalAllocatedBytes() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_without_scratch.elapsed_time_in_ms()) { - best_result_without_scratch = profile_result; - } + std::vector algorithms = GetAlgorithms(stream->parent()); + for (auto algorithm : algorithms) { + ConvolveScratchAllocator scratch_allocator( + buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + se::dnn::ProfileResult profile_result; + bool launch_ok = + Convolve(input_descriptor, input_data, filter_descriptor, filter_data, + output_descriptor, output_data, convolution_descriptor, + se::dnn::AlgorithmConfig(algorithm, algorithm), stream, + &scratch_allocator, &profile_result) + .ok(); + if (launch_ok && profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalAllocatedBytes() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_without_scratch.elapsed_time_in_ms()) { + best_result_without_scratch = profile_result; } } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 509719c1fe555fd733484c82ca14812efca0dcf9..13432301b2af34ab4bd0864e39ce22366cc1d11d 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -115,9 +115,7 @@ class ConvolutionThunk : public Thunk { perftools::gputools::dnn::ProfileResult* profile_result); // Returns the convolve algorithms that can be used for this ConvolutionThunk. - // TODO(nluehr) GetAlgorithms should return AlgorithmDesc including both - // tensor-op and non-tensor-op variants. - std::vector GetAlgorithms( + std::vector GetAlgorithms( perftools::gputools::StreamExecutor* stream_exec) const; // Fastest cuDNN convolution algorithm for this thunk learned from diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 8c1544007e911f48eeaf7b8db21523be083dfe40..3e16e4e3c42cebce75b4e4e95fd7c6477fb230ae 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -67,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -84,6 +85,8 @@ namespace gpu { namespace { +using tensorflow::strings::StrCat; + // Any address of a variable residing in global memory or returned by one of the // memory allocation routines from the driver or runtime API is always aligned // to at least 256 bytes. @@ -223,7 +226,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting( } // Invokes the ptxas tool on the given PTX string, and dumps its output. -void DumpPtxasInfo(const string& ptx) { +void DumpPtxasInfo(const string& ptx, int cc_major, int cc_minor) { const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas"); // Do not log PTX stats if ptxas is not found at the given path. @@ -245,17 +248,22 @@ void DumpPtxasInfo(const string& ptx) { // Invoke ptxas and collect its output. tensorflow::SubProcess ptxas_info_dumper; - ptxas_info_dumper.SetProgram(ptxas_path, {ptxas_path, ptx_path, "-o", - "/dev/null", "-v", "-arch=sm_35"}); + ptxas_info_dumper.SetProgram(ptxas_path, + {ptxas_path, ptx_path, "-o", "/dev/null", "-v", + StrCat("-arch=sm_", cc_major, cc_minor)}); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); - CHECK(ptxas_info_dumper.Start()); + if (!ptxas_info_dumper.Start()) { + LOG(ERROR) << "Failed to launch ptxas."; + return; + } string stderr_output; int exit_status = ptxas_info_dumper.Communicate( /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); XLA_LOG_LINES(tensorflow::INFO, stderr_output); if (exit_status != 0) { - LOG(FATAL) << "Invalid PTX. See the error message above for reasons."; + LOG(ERROR) << "ptxas exited with non-zero error code " << exit_status + << "."; } } @@ -324,7 +332,6 @@ StatusOr> GpuCompiler::Compile( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, - module->config().has_hybrid_result(), &ir_emitter_context); TF_RETURN_IF_ERROR( entry_computation->root_instruction()->Accept(&ir_emitter)); @@ -341,6 +348,16 @@ StatusOr> GpuCompiler::Compile( XLA_VLOG_LINES(2, ir_module_string_before_opt); } + const string& ir_dump_directory = + module->config().debug_options().xla_dump_ir_to(); + + if (!ir_dump_directory.empty()) { + TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( + /*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/module->name(), llvm_module, + /*optimized=*/false)); + } + // Reserve space for the PTX to be generated for this module. string* ptx; { @@ -363,6 +380,13 @@ StatusOr> GpuCompiler::Compile( TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, module->config(), libdevice_dir_)); + if (!ir_dump_directory.empty()) { + TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( + /*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/module->name(), llvm_module, + /*optimized=*/true)); + } + if (user_post_optimization_hook_) { TF_CHECK_OK(user_post_optimization_hook_(llvm_module)); } @@ -371,7 +395,7 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "PTX:"; XLA_VLOG_LINES(2, *ptx); if (VLOG_IS_ON(2)) { - DumpPtxasInfo(*ptx); + DumpPtxasInfo(*ptx, cc_major, cc_minor); } auto thunk_schedule = MakeUnique( @@ -392,7 +416,7 @@ StatusOr> GpuCompiler::Compile( StatusOr>> GpuCompiler::Compile( std::vector> modules, - std::vector stream_execs) { + std::vector> stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on GPU."); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index b5ffeef44ff5993cf822263680baab8deadd1799..58e835e5ee3f77b7b5cb3579514b7501bed2a2a1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -46,7 +46,8 @@ class GpuCompiler : public LLVMCompiler { StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) override; + std::vector> + stream_execs) override; StatusOr>> CompileAheadOfTime(std::vector> module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index cae310861908f0cee96a8b8dbf84e3568082ca6b..2c4d5150741d75ec2d1cb7e3d41c07ad24f800b0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -184,9 +184,6 @@ StatusOr GpuExecutable::ExecuteOnStream( HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - // This ExecuteOnStream overload should only be called if has_hybrid_result is - // false. - TF_RET_CHECK(!module_config().has_hybrid_result()); BufferAllocations::Builder buffer_allocations_builder; for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -264,9 +261,6 @@ StatusOr> GpuExecutable::ExecuteOnStream( tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - // This ExecuteOnStream overload should only be called by the LocalService - // which sets has_hybrid_result to true. - TF_RET_CHECK(module_config().has_hybrid_result()); if (GetRootPointsToSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 373c1aa5f9582fdb5f03f17f8a90a5e640f7b54d..152d226ab05ebb7342483ac127bb6ee16913face 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -67,7 +67,7 @@ void HloToIrBindings::EmitBasePointersForHlos( // Lookup allocation GetTupleElement operand. const BufferAllocation::Slice slice = buffer_assignment_ - ->GetUniqueTopLevelSlice(LatestNonGteAncestor(non_io_hlo)) + ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor()) .ConsumeValueOrDie(); // We are not in a nested context, so check non-thread-local allocation. CHECK(!slice.allocation()->is_thread_local()); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 0b94594f1dc5cd040846eabaad01b4cd09520e12..9a4bfd0905bb62c02c70e7f2eea46872c07bca89 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -152,8 +152,10 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { conv_window_col->set_padding_high(1); ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_batch_dimension(0); - conv_dnums.set_feature_dimension(1); + conv_dnums.set_input_batch_dimension(0); + conv_dnums.set_output_batch_dimension(0); + conv_dnums.set_input_feature_dimension(1); + conv_dnums.set_output_feature_dimension(1); conv_dnums.add_spatial_dimensions(2); conv_dnums.add_spatial_dimensions(3); conv_dnums.set_kernel_output_feature_dimension(0); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6be26dde8f957040c73db6a7e52f050e44d44c06..8fb7a6adda9dc7c36eb9aabcbcdc9d77e6c22c4a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -214,12 +214,5 @@ llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } -const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo) { - while (hlo->opcode() == HloOpcode::kGetTupleElement) { - hlo = hlo->operand(0); - } - return hlo; -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 422972762ee3da793852429a71b4cee76e41e2bc..06c3205296e4546e39525ec093cc17e2fc375d0d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -53,10 +53,6 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); -// Resolves GetTupleElement instruction operands starting with 'hlo'. -// Returns the first ancestor instruction which is not a GetTupleElement. -const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a76d217cac271bcda950c1c325f67810dd513383..3862c2190b1e2df824fa90eafc62bfdfe94e4789 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -34,7 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 2f6b3514497bff386d9f3e6f0d6c9737e8da4871..5e3f3bfdf18bdd5b4f8d0e565d1bb2613cebc3a1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -218,7 +218,6 @@ class IrEmitterUnnested : public IrEmitter { public: IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, - bool has_hybrid_result, IrEmitterContext* ir_emitter_context); IrEmitterUnnested(const IrEmitterUnnested&) = delete; IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; @@ -366,10 +365,6 @@ class IrEmitterUnnested : public IrEmitter { // The HloComputation that this IrEmitter emits code for. const HloComputation* hlo_computation_; - - // Whether this computation will produce a hybrid result, that is the - // computation produces a ShapedBuffer. - bool has_hybrid_result_; }; // Emits LLVM IR for a nested computation to the resultant function. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 958408e875fad904caccfd993e625d1c7b365fc5..120d50ed2582c43816a5e2ac757710cff13f43b7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -132,11 +133,9 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, - bool has_hybrid_result, IrEmitterContext* ir_emitter_context) : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false), - hlo_computation_(hlo_computation), - has_hybrid_result_(has_hybrid_result) { + hlo_computation_(hlo_computation) { // Initialize thunk_sequence_ to an empty list of thunks. thunk_sequence_.reset(new ThunkSequence()); } @@ -256,46 +255,6 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution, rhs_instruction, window); } -namespace { - -// Returns the first non-GetTupleElement ancestor instruction of 'hlo'. -// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the -// (possibly nested) tuple indices used on the path from ancestor to 'hlo'. -const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo, - ShapeIndex* index) { - if (hlo->opcode() == HloOpcode::kGetTupleElement) { - const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index); - index->push_back(hlo->tuple_index()); - return operand; - } - return hlo; -} - -// Checks if we can emit code for DynamicUpdateSlice to update data in-place. -// Returns true if operand 0 of DynamicUpdateSlice and its output buffer -// share the same buffer allocation. -// Returns false otherwise. -bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, - HloInstruction* fusion) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - HloInstruction* fused_root = fusion->fused_expression_root(); - if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice) { - return false; - } - // Walk DynamicUpdateSlice operand(0) to fused parameter and get its - // associated operand. See if it shares an allocation with this operand. - ShapeIndex index; - auto* fusion_operand = - LatestNonGteAncestorAndIndex(fused_root->operand(0), &index); - if (fusion_operand->opcode() != HloOpcode::kParameter) { - return false; - } - auto* operand = fusion->operand(fusion_operand->parameter_number()); - return assignment.SharesSliceAtIndex(fusion, {}, operand, index); -} - -} // namespace - Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); // HandleFusion specializes reduction from a multi-dimensional array to a 1D @@ -366,95 +325,40 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { LOG(FATAL) << "Bad opcode for input fusion: " << fusion->fused_expression_root()->opcode(); } - } else if (HloInstruction::FusionKind::kLoop == fusion->fusion_kind() && - root->opcode() == HloOpcode::kDynamicUpdateSlice && - CanUpdateDynamicSliceInPlace( - ir_emitter_context_->buffer_assignment(), fusion)) { - // Loop fusion instruction with DynamicUpdateSlice as fused root. - // DynamicUpdateSlice's operand(0) and 'fusion' output share the same - // BufferAllocation::Slice, so it is safe to emit code to update the slice - // 'in-place'. This avoids copying data outside of the slice update region. + } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace( + fusion, ir_emitter_context_->buffer_assignment())) { + // Fusion node with dynamic-update-slice as the root where the op's input + // (i.e. array to update) shares the same slice as its output. In this case + // we have a special algorithm that modifies the output in place without + // touching the un-updated elements. // Set up kernel thunk and fused ir emitter. thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); - std::vector parameter_arrays; + std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + operand_arrays.push_back(GetIrArray(*operand)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &ir_builder_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - - // Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0. - auto* fusion_operand = LatestNonGteAncestor(root->operand(0)); - CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode()); - - // Operand(0) the input array which shares an allocation with the output. - const auto* input = root->operand(0); - llvm::Value* input_base_ptr = fused_emitter.GetIrValueForGTE(input); - // Operand(1) 'update' is slice with which to update input at operand(0). - const auto* update = root->operand(1); - Shape update_shape = update->shape(); - TF_RETURN_IF_ERROR( - LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); - // Operand(2) the dynamic slice indices at which to write 'update'. - const auto* start_indices = root->operand(2); - - // Create element generators for 'update' and 'start_indices'. - llvm_ir::ElementGenerator element_generator = - fused_emitter.GetGenerator(update); - llvm_ir::ElementGenerator start_generator = - fused_emitter.GetGenerator(start_indices); - - // Create loop body emitter which emits code to do the following: - // *) Read dynamic slice start indices into 'start_index'. - // *) Map requested 'index' and slice 'start_index' to input/output shape - // as 'output_index'. - // *) Reads value from 'update' element generator. - // *) Writes value to input/output array at 'output_index'. - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& index) -> Status { - // Emit IR to read dynamic start indices from hlo->operand(2). - const int64 rank = ShapeUtil::Rank(input->shape()); - llvm_ir::IrArray::Index start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)}); - TF_ASSIGN_OR_RETURN(start_index[i], start_generator(dim_index)); - } - // Calculate 'output_index' at which to write value from update. - llvm_ir::IrArray::Index output_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // output_index = (start_index + index) % dim_size - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input->shape().dimensions(i)); - llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast( - start_index[i], index[i]->getType()); - output_index[i] = ir_builder_.CreateURem( - ir_builder_.CreateAdd(start_index0, index[i]), dim_size); - } + // Shape of the dynamic-update-slice's "update" operand. + Shape update_shape = root->operand(1)->shape(); - // Read value from 'update'. - TF_ASSIGN_OR_RETURN(llvm::Value * input_value, element_generator(index)); - // Write value to output array. - llvm_ir::IrArray(input_base_ptr, input->shape()) - .EmitWriteArrayElement(output_index, input_value, &ir_builder_); - return Status::OK(); - }; + // Array to write into. Because this is an in-place operation, this is the + // same as operand 0's array. + llvm_ir::IrArray output_array = GetIrArray(*fusion); - // Create loop which iterates over 'update' shape. LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); UpdateLaunchDimensions(launch_dimensions, static_cast(LastThunk()), ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, update_shape, - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(fusion)); + + return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( + fusion, operand_arrays, output_array, &elemental_emitter, + launch_dimensions, &ir_builder_); } if (ImplementedAsGemm(*fusion)) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); @@ -1372,13 +1276,6 @@ Status IrEmitterUnnested::HandleTuple( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } - // If `inst` is a nested thunk that can be disassembled from the result tuple, - // GpuExecutable will disassemble it and return it as part of the resultant - // ShapedBuffer. - if (has_hybrid_result_ && - ReachRootViaOnlyTuples(*tuple, *hlo_computation_->root_instruction())) { - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(tuple)); return IrEmitter::HandleTuple(tuple, operands); } @@ -1634,7 +1531,7 @@ llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( // with their operand buffer in 'io_hlos' and 'non_io_hlos' below. std::vector non_io_hlos; for (const HloInstruction* operand : hlo.operands()) { - const HloInstruction* to_lookup = LatestNonGteAncestor(operand); + const HloInstruction* to_lookup = operand->LatestNonGteAncestor(); if (buffer_assignment.HasTopLevelAllocation(to_lookup) && buffer_assignment.GetUniqueTopLevelSlice(to_lookup) .ConsumeValueOrDie() @@ -1674,7 +1571,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( std::vector io_buffers; io_buffers.reserve(io_hlos.size()); for (const HloInstruction* io_hlo : io_hlos) { - io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo))); + io_buffers.push_back(GetAllocationSlice(*io_hlo->LatestNonGteAncestor())); } // Create a KernelThunk that launches the kernel that implements "inst". @@ -1888,14 +1785,12 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( // Generate thunk sequence for while 'condition'. HloComputation* condition = hlo->while_condition(); IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, - /*has_hybrid_result=*/false, ir_emitter_context_); TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition)); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - false /* has_hybrid_result */, ir_emitter_context_); TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); @@ -1914,7 +1809,6 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( // Generate thunk sequence for while 'body' (will be used a For loop body). HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - false /* has_hybrid_result */, ir_emitter_context_); TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc index b0480e2f475c7b295b88d36a91ae08a90b818085..0bbd63fb7bfc657cb7bb1de673253c198f5bd25f 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc @@ -84,8 +84,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( --i) { input_layout.push_back(dimension_numbers.spatial_dimensions(i)); } - input_layout.push_back(dimension_numbers.feature_dimension()); - input_layout.push_back(dimension_numbers.batch_dimension()); + input_layout.push_back(dimension_numbers.input_feature_dimension()); + input_layout.push_back(dimension_numbers.input_batch_dimension()); Shape input_shape(input->shape()); *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); @@ -106,8 +106,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( --i) { output_layout.push_back(dimension_numbers.spatial_dimensions(i)); } - output_layout.push_back(dimension_numbers.feature_dimension()); - output_layout.push_back(dimension_numbers.batch_dimension()); + output_layout.push_back(dimension_numbers.output_feature_dimension()); + output_layout.push_back(dimension_numbers.output_batch_dimension()); Shape output_shape(output->shape()); *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index af853385d634b06d31cef94216fb4059dfcadc3d..79493c4112804f8454d200f3f83aa85d718f0d0a 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -39,6 +39,8 @@ message HloInstructionProto { string name = 1; string opcode = 2; xla.Shape shape = 3; + + // TODO(b/67782397): Replace instruction names with HloInstruction ids. repeated string operand_names = 4; repeated string control_predecessor_names = 5; repeated string called_computation_names = 6; @@ -58,6 +60,64 @@ message HloInstructionProto { // Index for kGetTupleElement. int64 tuple_index = 13; + + // Dimensions present for some operations that require reshaping or + // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. + repeated int64 dimensions = 14; + + // Describes the window in a windowed operation such as convolution. + xla.Window window = 15; + + // Describes the dimension numbers used for a convolution. + xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + + // Describes the [begin, end) index range and stride for slices. + message SliceDimensions { + int64 start = 1; + int64 limit = 2; + int64 stride = 3; + } + repeated SliceDimensions slice_dimensions = 17; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits = 18; + int32 mantissa_bits = 19; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + repeated int64 dynamic_slice_sizes = 20; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. Only set for pad instructions. + xla.PaddingConfig padding_config = 21; + + // Outfeed configuration information, only present for kOutfeed. + bytes outfeed_config = 22; + + // The distribution requested for random number generation. + // Only present for kRng. + xla.RandomDistribution distribution = 23; + + // A small float number added to the variance to avoid divide-by-zero error. + // Only present for kBatchNormTraining. + float epsilon = 24; + + // An integer value representing the index of the feature dimension. + // Only present for kBatchNormTraining. + int64 feature_index = 25; + + // Represents a unique identifier for each Send/Recv instruction pair. + // Only present for kSend or kRecv. + int64 channel_id = 26; + + // The string representation of the infeed configuration. + bytes infeed_config = 27; + + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target = 28; + + // Shape of outfeed request. + xla.Shape outfeed_shape = 29; } // Serialization of HloComputation. @@ -67,6 +127,9 @@ message HloComputationProto { // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; + + // The name of the root of the computation. + string root_name = 3; } // Serialization of HloModule. @@ -187,3 +250,7 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } + +message HloProtos { + repeated HloProto hlo_protos = 1; +} diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 444104d88fe34600e7dc38fc806ea34b74660da8..9b3104eaacdbb083db2a55c75fae3e94c8ff282f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -400,9 +400,38 @@ HloComputationProto HloComputation::ToProto() const { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } + proto.set_root_name(root_instruction()->name()); return proto; } +/* static */ StatusOr> +HloComputation::CreateFromProto( + HloModule* module, const HloComputationProto& proto, + tensorflow::gtl::FlatMap* computation_map, + HloInstruction* fusion_instruction) { + std::vector> instructions; + tensorflow::gtl::FlatMap instruction_map; + int64 parameter_count = 0; + for (const HloInstructionProto& instruction_proto : proto.instructions()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_count++; + } + TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); + instruction_map[instruction->name()] = instruction.get(); + instructions.push_back(std::move(instruction)); + } + + TF_RET_CHECK(!proto.root_name().empty()); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); + HloInstruction* root = instruction_map.at(proto.root_name()); + return WrapUnique(new HloComputation( + proto.name(), parameter_count, &instructions, root, fusion_instruction)); +} + void HloComputation::FuseInstructionsInto( tensorflow::gtl::ArraySlice instructions_to_fuse, HloInstruction* fusion_instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index b929b41bad403e76948e00fa627829565ed324fd..3515a6b5df2ed9a77bdf611adfbf14536aed8348 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -143,6 +143,22 @@ class HloComputation { // Returns a serialized representation of this computation. HloComputationProto ToProto() const; + // Creates a computation from the given proto. Arguments: + // + // module: the module which will contain the computation. The newly created + // computation is *not* added to the module, however. + // proto: the proto to convert from. + // computation_map: a map from computation name to HloComputation*. This map + // must contain all computations which the newly constructed computation + // calls. + // fusion_instruction: if non-null then the newly created computation will be + // constructed as a fused computation with this instruction as its fusion + // parent. + static StatusOr> CreateFromProto( + HloModule* module, const HloComputationProto& proto, + tensorflow::gtl::FlatMap* computation_map, + HloInstruction* fusion_instruction = nullptr); + // Gets the instructions in this computation. // // The returned type is a range of HloInstruction*s, so you can iterate over diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index ccab7bf34862f3303db1331a87b5c70fdc3283ba..7b7588f4ba9aa622677db6f9d5022cc8cc029e04 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -310,7 +310,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { } TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { - // Test that DeepCopyInstruction properly copies elements of a a tuple as + // Test that DeepCopyInstruction properly copies elements of a tuple as // specified by the given indices. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 65725ca692fb3429106f5ed50f4a2c11bd46f54c..84d55d4b5f83bd54940d3011037598deb6ec934b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -393,7 +393,7 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, const Window& window) { const auto& dnums = convolution->convolution_dimension_numbers(); const int64 output_features = - convolution->shape().dimensions(dnums.feature_dimension()); + convolution->shape().dimensions(dnums.output_feature_dimension()); // For each output element, we do one fma per element in the kernel at some // given output feature index. diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 71321e5e9ae733ea06d7988e2a301bf838022563..a4921232f5848dbe1789c4c641e2b0ba3c1848bb 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -64,6 +64,29 @@ StatusOr HloDCE::Run(HloModule* module) { } } + // Now DCE HloComputations. First, collect the computations that are + // referenced by some remaining instruction. + std::unordered_set live_computations; + if (HloComputation* entry_computation = module->entry_computation()) { + live_computations.insert(entry_computation); + } + for (auto* computation : module->MakeComputationPostOrder()) { + for (auto* instruction : computation->instructions()) { + for (auto* subcomp : instruction->called_computations()) { + live_computations.insert(subcomp); + } + } + } + + // Remove dead computations. + std::list computations = module->MakeComputationPostOrder(); + for (auto* computation : computations) { + if (live_computations.count(computation) == 0) { + TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); + changed = true; + } + } + return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index fa0ab98649abf0c452ca695af34d1c6f6ce116f5..d54b9a27087a42fd23eab0bd06e8deaca567312b 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -299,5 +299,93 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { EXPECT_TRUE(HasInstruction(*computation, live_call)); } +TEST_F(HloDceTest, RemoveDeadSubcomputation) { + auto module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation::Builder subcomp_builder("reduction_subcomp"); + { + auto* param0 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0")); + auto* param1 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1")); + subcomp_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1)); + } + auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build()); + + // Create a dead reduce instruction. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + // Add another instruction as the root of the computation. + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + // We should have DCE'ed the reduction computation along with the reduction + // instruction. + EXPECT_EQ(module->MakeComputationPostOrder().size(), 1); +} + +TEST_F(HloDceTest, KeepUsedSubcomputation) { + auto module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation::Builder subcomp_builder("reduction_subcomp"); + { + auto* param0 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0")); + auto* param1 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1")); + subcomp_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1)); + } + auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build()); + + // Create a dead reduce instruction. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + // Add another instruction as the root of the computation that also uses + // reduce_subcomp. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + // We shouldn't have DCE'ed reduce_subcomp, even though we removed one of + // its users. + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 4f9d6c00961d3027d07f2581ca410f88b6b2dad8..5fd891835d7bad0218c1d478f866d97bdf9dd7ca 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -255,12 +255,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleLogicalNot(HloInstruction* logical_not, - HloInstruction* operand) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_not], - ElementWiseUnaryOp(logical_not, - [](ReturnT elem_operand) { return !elem_operand; })); + Status HandleNot(HloInstruction* not_, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + return !elem_operand; + })); return Status::OK(); }; @@ -368,26 +367,113 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleAnd(HloInstruction* and_, HloInstruction* lhs, + HloInstruction* rhs) override { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_and], - ElementWiseBinaryOp(logical_and, [](ReturnT lhs_el, ReturnT rhs_el) { + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { return lhs_el && rhs_el; })); return Status::OK(); }; - Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleOr(HloInstruction* or_, HloInstruction* lhs, + HloInstruction* rhs) override { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_or], - ElementWiseBinaryOp(logical_or, [](ReturnT lhs_el, ReturnT rhs_el) { + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { return lhs_el || rhs_el; })); return Status::OK(); }; + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return lhs_elem << rhs_elem; + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftLeft(shl, lhs, rhs); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_signed::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightArithmetic(shra, lhs, rhs); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_unsigned::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightLogical(shrl, lhs, rhs); + } + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) override { std::function clamp_op = @@ -481,14 +567,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - // Dimension number applicable for both input (lhs), and output. - const int64 batch_dim = dnums.batch_dimension(); - const int64 z_dim = dnums.feature_dimension(); + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); // Dimension number applicable for kernel (rhs). const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, z_dim); + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); std::vector window_dimension_sizes; for (auto i : dnums.kernel_spatial_dimensions()) { @@ -509,13 +598,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::fill(rhs_index.begin(), rhs_index.end(), 0); std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0); - lhs_index[batch_dim] = out_index[batch_dim]; - rhs_index[kernel_output_z_dim] = out_index[z_dim]; + lhs_index[input_batch_dim] = out_index[output_batch_dim]; + rhs_index[kernel_output_z_dim] = out_index[output_z_dim]; // Convolve input feature with kernel. do { for (int64 iz = 0; iz < z_size; ++iz) { - lhs_index[z_dim] = iz; + lhs_index[input_z_dim] = iz; rhs_index[kernel_input_z_dim] = iz; // Find corresponding spatial dimension index for input (lhs). @@ -1241,8 +1330,14 @@ StatusOr> HloEvaluator::Evaluate( StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction) { - TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction)); - TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter); + if (instruction->opcode() == HloOpcode::kParameter) { + return tensorflow::errors::FailedPrecondition( + "Cannot evaluate a parameter."); + } + if (!hlo_query::AllOperandsAreConstants(*instruction)) { + return tensorflow::errors::FailedPrecondition( + "Not all operands are constants."); + } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_.clear(); @@ -1285,8 +1380,17 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( operands.push_back(operand.get()); } - return Evaluate( - instruction->CloneWithNewOperands(instruction->shape(), operands).get()); + std::unique_ptr cloned_instruction = + instruction->CloneWithNewOperands(instruction->shape(), operands); + auto result = Evaluate(cloned_instruction.get()); + + // Clean up our cloned instructions before returning. + cloned_instruction->DetachFromOperands(); + for (auto& operand : owned_operands) { + operand->DetachFromOperands(); + } + + return result; } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index a8a73e866ee08600dcdf58d7618b30514a2b4ca1..5172739624861972a32802a5148032eb83f6cda6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -736,8 +736,10 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.set_kernel_output_feature_dimension(0); @@ -868,8 +870,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(2); - dnums.set_feature_dimension(0); + dnums.set_input_batch_dimension(2); + dnums.set_output_batch_dimension(2); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(3); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 9b4a2f1048cb0644e6ba81e4e13115b608e4fcc0..24e390529e5cd02a4bb40d7aa861e852254fe253 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -777,9 +777,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -789,6 +789,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 99bec2c0beaf0271c8e188831c4fe2d6d250f1b0..021e5881c8af17de747b3189e7aae1d620a1035c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -47,6 +47,101 @@ using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; +/* static */ +StatusOr> HloInstruction::CreateFromProto( + HloModule* module, const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + tensorflow::gtl::FlatMap* computation_map) { + TF_RET_CHECK(!proto.opcode().empty()); + TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); + TF_RET_CHECK(proto.has_shape()); + + auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const string& operand_name : proto.operand_names()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) + << "No instruction named " << operand_name; + instruction->AppendOperand(instruction_map.at(operand_name)); + } + for (const string& predecessor_name : proto.control_predecessor_names()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) + << "No instruction named " << predecessor_name; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) + ->AddControlDependencyTo(instruction.get())); + } + + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within the + // HloModuleProto. + if (instruction->opcode() == HloOpcode::kFusion) { + TF_RET_CHECK(proto.has_fused_instructions_computation()); + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, + StringToFusionKind(proto.fusion_kind())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr fused_computation, + HloComputation::CreateFromProto( + module, proto.fused_instructions_computation(), computation_map, + /*fusion_instruction=*/instruction.get())); + instruction->called_computations_.push_back( + module->AddEmbeddedComputation(std::move(fused_computation))); + } else { + for (const string& computation_name : proto.called_computation_names()) { + TF_RET_CHECK(ContainsKey(*computation_map, computation_name)) + << "No computation named " << computation_name; + instruction->called_computations_.push_back( + computation_map->at(computation_name)); + } + } + + TF_RET_CHECK(!proto.name().empty()); + instruction->name_ = proto.name(); + + instruction->metadata_ = proto.metadata(); + if (proto.has_literal()) { + instruction->literal_ = MakeUnique(proto.literal()); + } + instruction->parameter_number_ = proto.parameter_number(); + instruction->parameter_name_ = proto.parameter_name(); + + instruction->tuple_index_ = proto.tuple_index(); + for (int64 dimension : proto.dimensions()) { + instruction->dimensions_.push_back(dimension); + } + if (proto.has_window()) { + instruction->window_ = MakeUnique(proto.window()); + } + if (proto.has_convolution_dimension_numbers()) { + instruction->convolution_dimension_numbers_ = + MakeUnique( + proto.convolution_dimension_numbers()); + } + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + instruction->slice_starts_.push_back(slice_dimensions.start()); + instruction->slice_limits_.push_back(slice_dimensions.limit()); + instruction->slice_strides_.push_back(slice_dimensions.stride()); + } + instruction->exponent_bits_ = proto.exponent_bits(); + instruction->mantissa_bits_ = proto.mantissa_bits(); + for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { + instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); + } + if (proto.has_padding_config()) { + instruction->padding_config_ = + MakeUnique(proto.padding_config()); + } + instruction->outfeed_config_ = proto.outfeed_config(); + instruction->distribution_ = proto.distribution(); + instruction->epsilon_ = proto.epsilon(); + instruction->feature_index_ = proto.feature_index(); + instruction->channel_id_ = proto.channel_id(); + instruction->infeed_config_ = proto.infeed_config(); + instruction->custom_call_target_ = proto.custom_call_target(); + instruction->outfeed_shape_ = proto.outfeed_shape(); + + return std::move(instruction); +} + /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { auto instruction = @@ -126,7 +221,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kFloor: case HloOpcode::kIsFinite: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kSign: case HloOpcode::kSin: @@ -161,8 +256,11 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case (HloOpcode::kPower): case (HloOpcode::kRemainder): case (HloOpcode::kSubtract): - case (HloOpcode::kLogicalAnd): - case (HloOpcode::kLogicalOr): + case (HloOpcode::kAnd): + case (HloOpcode::kOr): + case (HloOpcode::kShiftLeft): + case (HloOpcode::kShiftRightArithmetic): + case (HloOpcode::kShiftRightLogical): break; default: LOG(FATAL) << "Invalid binary instruction opcode " @@ -879,7 +977,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kSign: case HloOpcode::kSin: @@ -903,8 +1001,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kMinimum: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: CHECK_EQ(new_operands.size(), 2); return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); // Ternary ops. @@ -981,8 +1082,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kTranspose: CHECK_EQ(new_operands.size(), 1); return CreateTranspose(shape, new_operands[0], dimensions_); - case HloOpcode::kTuple: - return CreateTuple(new_operands); + case HloOpcode::kTuple: { + auto new_tuple = CreateTuple(new_operands); + *new_tuple->mutable_shape() = shape; + return new_tuple; + } case HloOpcode::kWhile: CHECK_EQ(new_operands.size(), 1); return CreateWhile(shape, while_condition(), while_body(), @@ -1131,6 +1235,29 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( return new_instruction; } +std::pair +HloInstruction::LatestNonGteAncestorAndIndex() const { + const HloInstruction* hlo = this; + ShapeIndex index; + while (hlo->opcode() == HloOpcode::kGetTupleElement) { + index.push_back(hlo->tuple_index()); + hlo = hlo->operand(0); + } + + // We built up index in the reverse order from what we want. + std::reverse(index.begin(), index.end()); + + return {hlo, index}; +} + +const HloInstruction* HloInstruction::LatestNonGteAncestor() const { + const HloInstruction* hlo = this; + while (hlo->opcode() == HloOpcode::kGetTupleElement) { + hlo = hlo->operand(0); + } + return hlo; +} + const Literal& HloInstruction::literal() const { CHECK_EQ(HloOpcode::kConstant, opcode_); return *literal_; @@ -1258,9 +1385,9 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1270,6 +1397,9 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSubtract: @@ -1702,6 +1832,10 @@ std::vector HloInstruction::ExtraAttributesToString() const { }))); } + if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) { + extra.push_back(StrCat("channel_id=", channel_id_)); + } + if (opcode() == HloOpcode::kGetTupleElement) { extra.push_back(StrCat("index=", tuple_index())); } @@ -1735,37 +1869,59 @@ HloInstructionProto HloInstruction::ToProto() const { for (const HloInstruction* control : control_predecessors_) { *proto.add_control_predecessor_names() = control->name(); } - for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); - } + *proto.mutable_metadata() = metadata_; - switch (opcode_) { - case HloOpcode::kConstant: - *proto.mutable_literal() = literal_->ToProto(); - break; - case HloOpcode::kParameter: - proto.set_parameter_number(parameter_number_); - proto.set_parameter_name(parameter_name_); - break; - case HloOpcode::kFusion: { - HloComputationProto* proto_fused_computation = - proto.mutable_fused_instructions_computation(); - proto_fused_computation->set_name(name()); - - // Fill in fused instructions in post order. - auto fused_instructions = - fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_instruction : fused_instructions) { - HloInstructionProto fused_proto = fused_instruction->ToProto(); - proto_fused_computation->add_instructions()->Swap(&fused_proto); - } - break; + if (literal_ != nullptr) { + *proto.mutable_literal() = literal_->ToProto(); + } + proto.set_parameter_number(parameter_number_); + proto.set_parameter_name(parameter_name_); + if (opcode() == HloOpcode::kFusion) { + proto.set_fusion_kind(xla::ToString(fusion_kind())); + *proto.mutable_fused_instructions_computation() = + fused_instructions_computation()->ToProto(); + } else { + for (const HloComputation* computation : called_computations_) { + *proto.add_called_computation_names() = computation->name(); } - case HloOpcode::kGetTupleElement: - proto.set_tuple_index(tuple_index_); - break; - default: {} // Nothing to do } + + proto.set_tuple_index(tuple_index_); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + if (window_ != nullptr) { + *proto.mutable_window() = *window_; + } + if (convolution_dimension_numbers_ != nullptr) { + *proto.mutable_convolution_dimension_numbers() = + *convolution_dimension_numbers_; + } + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + for (int64 slice_size : dynamic_slice_sizes_) { + proto.add_dynamic_slice_sizes(slice_size); + } + if (padding_config_ != nullptr) { + *proto.mutable_padding_config() = *padding_config_; + } + proto.set_outfeed_config(outfeed_config_); + if (opcode() == HloOpcode::kRng) { + proto.set_distribution(distribution_); + } + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + proto.set_channel_id(channel_id_); + proto.set_infeed_config(infeed_config_); + proto.set_custom_call_target(custom_call_target_); + *proto.mutable_outfeed_shape() = outfeed_shape_; + return proto; } @@ -1953,10 +2109,17 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleMaximum(this); case HloOpcode::kMinimum: return visitor->HandleMinimum(this); - case HloOpcode::kLogicalAnd: - return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); - case HloOpcode::kLogicalOr: - return visitor->HandleLogicalOr(this, operands_[0], operands_[1]); + case HloOpcode::kAnd: + return visitor->HandleAnd(this, operands_[0], operands_[1]); + case HloOpcode::kOr: + return visitor->HandleOr(this, operands_[0], operands_[1]); + case HloOpcode::kShiftLeft: + return visitor->HandleShiftLeft(this, operands_[0], operands_[1]); + case HloOpcode::kShiftRightArithmetic: + return visitor->HandleShiftRightArithmetic(this, operands_[0], + operands_[1]); + case HloOpcode::kShiftRightLogical: + return visitor->HandleShiftRightLogical(this, operands_[0], operands_[1]); case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: @@ -2012,8 +2175,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleSin(this, operands_[0]); case HloOpcode::kIsFinite: return visitor->HandleIsFinite(this, operands_[0]); - case HloOpcode::kLogicalNot: - return visitor->HandleLogicalNot(this, operands_[0]); + case HloOpcode::kNot: + return visitor->HandleNot(this, operands_[0]); case HloOpcode::kBitcast: return visitor->HandleBitcast(this); case HloOpcode::kBroadcast: @@ -2315,8 +2478,11 @@ bool HloInstruction::IsElementwiseBinary() const { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; default: return false; @@ -2340,7 +2506,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kFloor: case HloOpcode::kIsFinite: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReducePrecision: case HloOpcode::kSign: @@ -2364,8 +2530,11 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; // Ternary elementwise operations. @@ -2584,6 +2753,32 @@ string ToString(HloInstruction::FusionKind kind) { } } +StatusOr StringToFusionKind( + const string& kind_name) { + if (kind_name == "kLoop") { + return HloInstruction::FusionKind::kLoop; + } + if (kind_name == "kInput") { + return HloInstruction::FusionKind::kInput; + } + if (kind_name == "kOutput") { + return HloInstruction::FusionKind::kOutput; + } + if (kind_name == "kTransposeDot") { + return HloInstruction::FusionKind::kTransposeDot; + } + if (kind_name == "kConvBackwardFilter") { + return HloInstruction::FusionKind::kConvBackwardFilter; + } + if (kind_name == "kConvBackwardInput") { + return HloInstruction::FusionKind::kConvBackwardInput; + } + if (kind_name == "kCustom") { + return HloInstruction::FusionKind::kCustom; + } + return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2611,8 +2806,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); - lhs_dims[dnums.batch_dimension()] = 'b'; - lhs_dims[dnums.feature_dimension()] = 'f'; + lhs_dims[dnums.input_batch_dimension()] = 'b'; + lhs_dims[dnums.input_feature_dimension()] = 'f'; for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); } @@ -2624,12 +2819,19 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); } + std::vector output_dims(2 + dnums.spatial_dimensions().size()); + output_dims[dnums.output_batch_dimension()] = 'b'; + output_dims[dnums.output_feature_dimension()] = 'f'; + for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + output_dims[dnums.spatial_dimensions(i)] = StrCat(i); + } + result += "dim_labels="; append_dims(lhs_dims, operand(0)->shape()); result += "_"; append_dims(rhs_dims, operand(1)->shape()); result += "->"; - append_dims(lhs_dims, shape()); + append_dims(output_dims, shape()); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 73c4ebd9f11c722c699f3eb6e78703e837fe57db..d2a15b0f962317cc79ab93cf377a77939a5eba41 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -72,6 +72,23 @@ class HloInstruction { }; ~HloInstruction(); + + // Creates an instruction from the given proto. Arguments: + // + // module: the module which will contain the instruction. The newly created + // instruction is *not* added to the module or any computation, however. + // proto: the proto to convert from. + // instruction_map: a map from instruction name to HloInstruction*. This map + // must contain all operands of the newly constructed instruction. + // computation_map: a map from computation name to HloComputation*. This map + // must contain all computations which the newly constructed instruction + // calls. If the instruction is a fusion instruction, then the fusion + // computation is added to this map and the module. + static StatusOr> CreateFromProto( + HloModule* module, const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + tensorflow::gtl::FlatMap* computation_map); + // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, const Shape& shape, @@ -508,6 +525,26 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kGetTupleElement int64 tuple_index() const; + // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. + // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the + // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. + std::pair LatestNonGteAncestorAndIndex() + const; + + std::pair LatestNonGteAncestorAndIndex() { + auto rv = + const_cast(this)->LatestNonGteAncestorAndIndex(); + return {const_cast(rv.first), rv.second}; + } + + // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. + const HloInstruction* LatestNonGteAncestor() const; + + HloInstruction* LatestNonGteAncestor() { + return const_cast( + const_cast(this)->LatestNonGteAncestor()); + } + // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. // The setter should only be called by HloModule or HloComputation methods. // @@ -1055,7 +1092,7 @@ class HloInstruction { std::unique_ptr literal_; // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = 0; + int64 tuple_index_ = -1; // Dimensions present for some operations that require reshaping or // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. @@ -1073,8 +1110,8 @@ class HloInstruction { std::vector slice_strides_; // The bit sizes for a reduce-precision operation. - int32 exponent_bits_; - int32 mantissa_bits_; + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). @@ -1124,11 +1161,11 @@ class HloInstruction { // A small float number added to the variance to avoid divide-by-zero error. // Only present for kBatchNormTraining. - float epsilon_; + float epsilon_ = 0.0f; // An integer value representing the index of the feature dimension. // Only present for kBatchNormTraining. - int64 feature_index_; + int64 feature_index_ = -1; // Represents a unique identifier for each Send/Recv instruction pair. // Only present for kSend or kRecv. @@ -1154,6 +1191,8 @@ class HloInstruction { }; string ToString(HloInstruction::FusionKind kind); +StatusOr StringToFusionKind( + const string& kind_name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 3601d5cdbe66b305b4fa23fa5e1c519704befbb9..45f9128eab766797030f8ab69700d8979e97f918 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -729,6 +729,23 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10)); } +TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { + HloComputation::Builder builder(TestName()); + auto* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({ + {1, 2}, + {3, 4}, + }))); + auto* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {0}) + ->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {1}) + ->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + auto tuple_clone = tuple->Clone(); + EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape())); +} + TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index b1b3dd61a63d8f729912c5b533099f739f9aa9c4..d1ae5f776d281aa4cad157c9e2bc1f2c1133b37f 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -79,9 +79,9 @@ HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); HLO_MATCHER(Log); -HLO_MATCHER(LogicalAnd); -HLO_MATCHER(LogicalNot); -HLO_MATCHER(LogicalOr); +HLO_MATCHER(And); +HLO_MATCHER(Not); +HLO_MATCHER(Or); HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); @@ -104,6 +104,9 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(ShiftLeft); +HLO_MATCHER(ShiftRightLogical); +HLO_MATCHER(ShiftRightArithmetic); HLO_MATCHER(Sign); HLO_MATCHER(Slice); HLO_MATCHER(Sort); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 14590112a1edd16c0c1ab16d9e1d2aac5ce66e18..5bc7a3643936b3cb3ef066b4f741c934f5e850d3 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -167,12 +167,45 @@ HloModuleProto HloModule::ToProto() const { proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); for (const HloComputation* computation : MakeComputationPostOrder()) { + // Fusion computations are added when the fusion instructions are created by + // HloInstruction::CreateFromProto. + if (computation->IsFusionComputation()) { + continue; + } HloComputationProto computation_proto = computation->ToProto(); proto.add_computations()->Swap(&computation_proto); } return proto; } +/* static */ +StatusOr> HloModule::CreateFromProto( + const HloModuleProto& proto, + const VersionedComputationHandle& entry_computation_handle, + const HloModuleConfig& config) { + auto module = + MakeUnique(proto.name(), entry_computation_handle, config); + tensorflow::gtl::FlatMap computation_map; + for (const HloComputationProto& computation_proto : proto.computations()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, &computation_map)); + CHECK_NE(computation.get(), nullptr); + TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); + string computation_name = computation->name(); + if (proto.entry_computation_name() == computation_name) { + computation_map[computation_name] = + module->AddEntryComputation(std::move(computation)); + } else { + computation_map[computation_name] = + module->AddEmbeddedComputation(std::move(computation)); + } + } + TF_RET_CHECK(module->entry_computation_ != nullptr); + + return std::move(module); +} + namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3546f4b3f7982d4bb325888da2438965aa946b7c..96c17d62970d48ccf590c44115c61c89fd379efe 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -140,7 +140,13 @@ class HloModule { const HloModuleConfig& config() const { return config_; } string ToString() const; + + // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; + static StatusOr> CreateFromProto( + const HloModuleProto& proto, + const VersionedComputationHandle& entry_computation_handle, + const HloModuleConfig& config); // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 2299200b5be969c065fded840709a3d6034efe47..4a7ead9c104d2ed50d5c895b3cdf2d3767ae16e8 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -67,11 +67,6 @@ class HloModuleConfig { bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; } void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; } - bool has_hybrid_result() const { return has_hybrid_result_; } - void set_has_hybrid_result(bool has_hybrid_result) { - has_hybrid_result_ = has_hybrid_result; - } - // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 83fe6ef6c967f865333eff51b04a33b1d11ffa7e..db3abeab22044de372c6fb6237d7a4b859884ec9 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -95,12 +97,12 @@ string HloOpcodeString(HloOpcode opcode) { return "less-than-or-equal-to"; case HloOpcode::kLog: return "log"; - case HloOpcode::kLogicalAnd: - return "logical-and"; - case HloOpcode::kLogicalOr: - return "logical-or"; - case HloOpcode::kLogicalNot: - return "logical-not"; + case HloOpcode::kAnd: + return "and"; + case HloOpcode::kOr: + return "or"; + case HloOpcode::kNot: + return "not"; case HloOpcode::kLt: return "less-than"; case HloOpcode::kMap: @@ -147,6 +149,12 @@ string HloOpcodeString(HloOpcode opcode) { return "select"; case HloOpcode::kSend: return "send"; + case HloOpcode::kShiftLeft: + return "shift-left"; + case HloOpcode::kShiftRightArithmetic: + return "shift-right-arithmetic"; + case HloOpcode::kShiftRightLogical: + return "shift-right-logical"; case HloOpcode::kSign: return "sign"; case HloOpcode::kSin: @@ -172,6 +180,89 @@ string HloOpcodeString(HloOpcode opcode) { } } +StatusOr StringToHloOpcode(const string& opcode_name) { + static auto* opcode_map = new tensorflow::gtl::FlatMap( + {{"abs", HloOpcode::kAbs}, + {"add", HloOpcode::kAdd}, + {"batch-norm-training", HloOpcode::kBatchNormTraining}, + {"batch-norm-inference", HloOpcode::kBatchNormInference}, + {"batch-norm-grad", HloOpcode::kBatchNormGrad}, + {"bitcast", HloOpcode::kBitcast}, + {"broadcast", HloOpcode::kBroadcast}, + {"call", HloOpcode::kCall}, + {"clamp", HloOpcode::kClamp}, + {"concatenate", HloOpcode::kConcatenate}, + {"constant", HloOpcode::kConstant}, + {"convert", HloOpcode::kConvert}, + {"convolution", HloOpcode::kConvolution}, + {"cosine", HloOpcode::kCos}, + {"cross-replica-sum", HloOpcode::kCrossReplicaSum}, + {"custom-call", HloOpcode::kCustomCall}, + {"copy", HloOpcode::kCopy}, + {"divide", HloOpcode::kDivide}, + {"dot", HloOpcode::kDot}, + {"dynamic-slice", HloOpcode::kDynamicSlice}, + {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice}, + {"equal-to", HloOpcode::kEq}, + {"exponential", HloOpcode::kExp}, + {"floor", HloOpcode::kFloor}, + {"ceil", HloOpcode::kCeil}, + {"fusion", HloOpcode::kFusion}, + {"greater-than-or-equal-to", HloOpcode::kGe}, + {"get-tuple-element", HloOpcode::kGetTupleElement}, + {"greater-than", HloOpcode::kGt}, + {"index", HloOpcode::kIndex}, + {"infeed", HloOpcode::kInfeed}, + {"is-finite", HloOpcode::kIsFinite}, + {"less-than-or-equal-to", HloOpcode::kLe}, + {"log", HloOpcode::kLog}, + {"and", HloOpcode::kAnd}, + {"or", HloOpcode::kOr}, + {"not", HloOpcode::kNot}, + {"less-than", HloOpcode::kLt}, + {"map", HloOpcode::kMap}, + {"maximum", HloOpcode::kMaximum}, + {"minimum", HloOpcode::kMinimum}, + {"multiply", HloOpcode::kMultiply}, + {"not-equal-to", HloOpcode::kNe}, + {"negate", HloOpcode::kNegate}, + {"outfeed", HloOpcode::kOutfeed}, + {"pad", HloOpcode::kPad}, + {"parameter", HloOpcode::kParameter}, + {"power", HloOpcode::kPower}, + {"recv", HloOpcode::kRecv}, + {"reduce", HloOpcode::kReduce}, + {"reduce-precision", HloOpcode::kReducePrecision}, + {"reduce-window", HloOpcode::kReduceWindow}, + {"remainder", HloOpcode::kRemainder}, + {"reshape", HloOpcode::kReshape}, + {"reverse", HloOpcode::kReverse}, + {"rng", HloOpcode::kRng}, + {"round-nearest-afz", HloOpcode::kRoundNearestAfz}, + {"select-and-scatter", HloOpcode::kSelectAndScatter}, + {"select", HloOpcode::kSelect}, + {"send", HloOpcode::kSend}, + {"shift-left", HloOpcode::kShiftLeft}, + {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic}, + {"shift-right-logical", HloOpcode::kShiftRightLogical}, + {"sign", HloOpcode::kSign}, + {"sine", HloOpcode::kSin}, + {"slice", HloOpcode::kSlice}, + {"sort", HloOpcode::kSort}, + {"subtract", HloOpcode::kSubtract}, + {"tanh", HloOpcode::kTanh}, + {"trace", HloOpcode::kTrace}, + {"transpose", HloOpcode::kTranspose}, + {"tuple", HloOpcode::kTuple}, + {"update", HloOpcode::kUpdate}, + {"while", HloOpcode::kWhile}}); + auto it = opcode_map->find(opcode_name); + if (it == opcode_map->end()) { + return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + } + return it->second; +} + bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { case HloOpcode::kGe: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 7b23249640b0dcfdd510caf27bf57bb1f2f6850e..4593df671e34b1ec1f6e388439df37adf63b621f 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -62,9 +63,9 @@ enum class HloOpcode { kIsFinite, kLe, kLog, - kLogicalAnd, - kLogicalNot, - kLogicalOr, + kAnd, + kNot, + kOr, kLt, kMap, kMaximum, @@ -88,6 +89,9 @@ enum class HloOpcode { kSelect, kSelectAndScatter, kSend, + kShiftLeft, + kShiftRightArithmetic, + kShiftRightLogical, kSign, kSin, kSlice, @@ -104,6 +108,9 @@ enum class HloOpcode { // Returns a string representation of the opcode. string HloOpcodeString(HloOpcode opcode); +// Returns a string representation of the opcode. +StatusOr StringToHloOpcode(const string& opcode_name); + inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { return os << HloOpcodeString(opcode); } diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc index 0682434bfbac42ac3839c7066f167b7505dfdd0a..6ea0f127d53404af9514820b36a97bb0526aa5f9 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/inliner.cc @@ -90,8 +90,12 @@ Status InlinerVisitor::HandleMap( // different than the map shape. Hence, a broadcast is needed, else the // cloned operand with new shape and operands work. if (root.opcode() != HloOpcode::kConstant) { + std::vector params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(operands[root.operand(o)->parameter_number()]); + } HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), operands)); + root.CloneWithNewOperands(map->shape(), params)); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); } else { diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 9d845c5545680f1a5389fa89af02eb723203a5f7..7aa1c7c8358318d02a000d968a2672123400ad6e 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -108,5 +108,44 @@ TEST_F(InlinerTest, MapConstant) { LiteralTestUtil::ExpectEqual(*result, *expected); } +TEST_F(InlinerTest, MapSubtractOppositeOrder) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + // Note that the parameter ordinals are in the opposite order to their + // position as operands + auto max_builder = HloComputation::Builder(TestName()); + auto param1 = max_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "x")); + auto param2 = max_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "y")); + max_builder.AddInstruction(HloInstruction::CreateBinary( + param1->shape(), HloOpcode::kSubtract, param1, param2)); + auto max_f32 = max_builder.Build(); + + auto builder = HloComputation::Builder("MapSubFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewModule(); + hlo_module->AddEmbeddedComputation(std::move(max_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + Inliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), + op::Subtract(rhs, lhs)); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto expected = Literal::CreateR1({3, 1, -1, -3}); + LiteralTestUtil::ExpectEqual(*result, *expected); +} + + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 177d2e2a9399c9cc8befff913b38fd286fa9e9fc..7e46d79ba41cc27894de892c100d5e71eb3153f1 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -53,9 +53,9 @@ namespace xla { case HloOpcode::kInfeed: case HloOpcode::kIsFinite: case HloOpcode::kLe: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -69,6 +69,9 @@ namespace xla { case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: @@ -203,6 +206,9 @@ bool InstructionFusion::CanFuseOnAllPaths( } StatusOr InstructionFusion::Run(HloModule* module) { + VLOG(2) << "Before instruction fusion:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; module_ = module; for (auto* computation : module->MakeNonfusionComputations()) { @@ -371,6 +377,10 @@ StatusOr InstructionFusion::Run(HloModule* module) { } } } + + VLOG(2) << "After instruction fusion:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index c8d02834f43a747980d084be37602bc56db74b98..93ea2f736742eab06ee0d7e881ee7c51daee9878 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -88,7 +88,7 @@ StatusOr> InterpreterCompiler::Compile( StatusOr>> InterpreterCompiler::Compile( std::vector> /*hlo_modules*/, - std::vector /*stream_execs*/) { + std::vector> /*stream_execs*/) { return tensorflow::errors::Unimplemented( "Compilation of multiple HLO modules is not supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index 13db38ab60a07bdf476227c9b9e818dfe2cdcc6c..cfdc9b6256569b0137784b0d1db846a5f2339a5d 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -49,7 +49,8 @@ class InterpreterCompiler : public Compiler { StatusOr>> Compile( std::vector> hlo_modules, - std::vector stream_exec) override; + std::vector> + stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> hlo_modules, diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 8fd330fda715276894c5e4e5c9fc7f8c4416105d..2058706f1120238f63c06c8dcac79b8487888df5 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1180,8 +1180,6 @@ Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, // to match the layout of its corresponding fusion instruction operand. Also, // set the layout of the fused root to match the layout of the fusion // instruction itself. -// Fused GetTupleElement requires a layout so that TBAA metadata for the tuple -// element array pointer load can be added. Status SetFusionLayouts(HloInstruction* fusion) { TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion); for (auto* fused_instruction : fusion->fused_instructions()) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 86817b05f52f254184c53d9bcb6dd8ca14a7d39b..075d4a1ab5e5f39394ade393d21525ca3e97136e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", @@ -92,7 +93,6 @@ cc_library( deps = [ ":ir_array", ":llvm_loop", - ":ops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -111,7 +111,7 @@ cc_library( ":ir_array", ":llvm_util", ":loop_emitter", - ":ops", + ":tuple_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -127,6 +127,23 @@ cc_library( name = "ops", srcs = ["ops.cc"], hdrs = ["ops.h"], + deps = [ + ":fused_ir_emitter", + ":ir_array", + ":llvm_util", + ":loop_emitter", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/gpu:partition_assignment", + ], +) + +cc_library( + name = "tuple_ops", + srcs = ["tuple_ops.cc"], + hdrs = ["tuple_ops.h"], deps = [ ":ir_array", ":llvm_util", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 5e28e37600c18a351e8647d48119f073277f56e1..bdddc232ef74dfa37e2d5cc780b0fe11e7bc8e76 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -92,7 +92,16 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, llvm::MDNode* AliasAnalysis::GetAliasDomain() { llvm::MDBuilder metadata_builder(*context_); if (alias_domain_ == nullptr) { - alias_domain_ = metadata_builder.createAnonymousAliasScopeDomain(); + // We use createAliasScopeDomain rather than createAnonymousAliasScopeDomain + // so that when functions get inlined, we continue using the one domain, + // rather than duplicating it (and thus having two AA domains in one + // function). + // + // A side-effect of this is that if you ever compile two HLO modules in the + // same LLVM module, they'll have the same alias scope domain. This isn't a + // problem because the two HLO modules will never interact with one another. + alias_domain_ = + metadata_builder.createAliasScopeDomain("XLA global AA domain"); } return alias_domain_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 7d1fad753e0d94f7d88b824ed57d52890a48b1dd..d286c49d6868c91026c8901b7871a322dabd38ec 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e36c791c1a52f4e0699cc15ef913fbd2bdcca557..6a00a565c6d23aa8cd5f4e17621de8ca99dd1c5d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -268,8 +268,6 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder, name); llvm::LoadInst* load = ir_builder->CreateLoad(element_address); - llvm_ir::SetTbaaForInstruction(load, GetShape(), - /*is_pointer_to=*/false); AnnotateLoadStoreInstructionWithMetadata(load); return load; } @@ -278,8 +276,6 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, llvm::IRBuilder<>* ir_builder) const { llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder); llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); - llvm_ir::SetTbaaForInstruction(store, GetShape(), - /*is_pointer_to=*/false); AnnotateLoadStoreInstructionWithMetadata(store); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 9498d40214f9e257359a48394d6c44585ecb4bff..8e188e7ae848b093abb2f7ba84b36413d397f7c8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include +#include #include #include "llvm/IR/MDBuilder.h" @@ -25,9 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -398,13 +402,6 @@ void EmitLogging(const char* tag, llvm::Value* value, {ir_builder->getInt64(tensorflow::bit_cast(tag)), value}); } -void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, - bool is_pointer_to) { - // TODO(b/62903316): TBAA metadata causes LLVM to miscompile generated code, - // most likely because the generated metadata is incorrect. Disable TBAA - // metadata while we resolve this. -} - void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { llvm::LLVMContext& context = load->getContext(); llvm::Type* int64_ty = llvm::Type::getInt64Ty(context); @@ -582,5 +579,23 @@ std::map MergeMetadata( return result; } +Status DumpIRToDirectory(const string& directory_name, + const string& hlo_module_name, + const llvm::Module& llvm_module, bool optimized) { + string safe_file_name_base = SanitizeFileName(hlo_module_name); + string ir_file_name = tensorflow::io::JoinPath( + directory_name, + tensorflow::strings::StrCat("ir-", safe_file_name_base, "-", + optimized ? "with" : "no", "-opt.ll")); + + std::unique_ptr f; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewWritableFile(ir_file_name, &f)); + TF_RETURN_IF_ERROR(f->Append(DumpModuleToString(llvm_module))); + return f->Close(); +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index ab8ac5e745dae3649b9d1cc62424aaaac50b6360..7a7d14da1eb62ab3d6401d2eff64c301c93a3806 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -227,12 +227,6 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* ir_builder); -// Adds TBAA metadata to a load or store instruction using the given shape as -// it's type. The is_pointer_to parameter is used to indicate whether or not -// this instruction loads or stores a pointer to an array. -void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, - bool is_pointer_to); - // Adds alignment metadata to a load instruction using the given alignment. // The alignment refers to the result of the load, not the load itself. void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment); @@ -273,6 +267,15 @@ std::map MergeMetadata( llvm::LLVMContext* context, const std::map& a, const std::map& b); +// Dumps out `llvm_module` to a file in the directory named `directory_name`, +// creating the directory if necessary. A sanitized version of +// `hlo_module_name` is incorporated into the file name. If `optimized` is true +// then a suffix of "-with-opt.ll" is used, else a suffix of "-no-opt.ll" is +// used. +Status DumpIRToDirectory(const string& directory_name, + const string& hlo_module_name, + const llvm::Module& llvm_module, bool optimized); + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 8bba1776d19005292da48705df2436b6f30e0f2d..6fa4cd08c9e0ac30b83c0e2b49d98d930c2e15df 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index ac562e231c8f56184363d6e186c18847d67435ce..34899b7400464e4f4f97d301f35ed3b7b083bca1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -14,86 +14,167 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" - -#include -#include -#include - -#include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) { - CHECK(ShapeUtil::IsScalar(pred.GetShape())); - - llvm::LoadInst* pred_value = - ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder->CreateICmpNE( - pred_value, - llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0), - "boolean_predicate"); - - VLOG(2) << "HandleSelect for tuple:"; - VLOG(2) << " pred_value: " << DumpToString(*pred_value); - VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); - - for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; - llvm::Value* on_true_element_address = - ir_builder->CreateInBoundsGEP(on_true, element_index); - llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); - llvm::Value* on_false_element_address = - ir_builder->CreateInBoundsGEP(on_false, element_index); - llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); - - llvm::Value* output_element_address = - ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); - ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), - output_element_address); - } +bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, + const BufferAssignment& assignment) { + CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); + const HloInstruction* operand = dynamic_update_slice->operand(0); + return assignment.HasTopLevelAllocation(dynamic_update_slice) && + assignment.HasTopLevelAllocation(operand) && + assignment.SharesTopLevelSlice(dynamic_update_slice, operand); } -void EmitTuple(IrArray tuple, - tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder) { - for (size_t i = 0; i < operands.size(); ++i) { - ir_builder->CreateStore( - ir_builder->CreatePointerCast(operands[i], - PrimitiveTypeToIrType(TUPLE, ir_builder)), - ir_builder->CreateInBoundsGEP( - tuple.GetBasePointer(), - {ir_builder->getInt64(0), ir_builder->getInt64(i)})); +// Shared implementation of EmitDynamicUpdateSliceInPlace and +// EmitFusedDynamicUpdateSliceInPlace. +// +// Emits a sequential loop if launch_dimensions is null. +static Status EmitDynamicUpdateSliceInPlaceImpl( + const Shape& update_shape, const ElementGenerator& start_indices_generator, + ElementGenerator update_array_generator, const IrArray& output_array, + const gpu::LaunchDimensions* launch_dimensions, + tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) { + const Shape& output_shape = output_array.GetShape(); + + // Read start indices from start_indices_generator. + const int64 rank = ShapeUtil::Rank(output_shape); + IrArray::Index start_index(rank); + for (int64 i = 0; i < rank; ++i) { + IrArray::Index dim_index({ir_builder->getInt64(i)}); + TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); } + + auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { + // Calculate output_index, where we'll write the value from update. For + // each dimension, + // + // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // + IrArray::Index output_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm::Value* dim_size = llvm::ConstantInt::get( + update_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + start_index[i], update_index[i]->getType()); + output_index[i] = ir_builder->CreateURem( + ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + } + + // Do output[output_index] = update[update_index]. + TF_ASSIGN_OR_RETURN(llvm::Value * update_data, + update_array_generator(update_index)); + output_array.EmitWriteArrayElement(output_index, update_data, ir_builder); + return Status::OK(); + }; + + if (launch_dimensions != nullptr) { + return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape, + *launch_dimensions, ir_builder) + .EmitLoop(name); + } + return LoopEmitter(loop_body_emitter, update_shape, ir_builder) + .EmitLoop(name); +} + +Status EmitDynamicUpdateSliceInPlace( + tensorflow::gtl::ArraySlice operand_arrays, + const IrArray& output_array, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder) { + VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; + + // No need to use operand_arrays[0], the input array of the + // dynamic-update-slice, because we know it aliases the op's output. + IrArray update_array = operand_arrays[1]; + IrArray start_indices_array = operand_arrays[2]; + Shape output_shape = output_array.GetShape(); + Shape update_shape = update_array.GetShape(); + + ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { + return start_indices_array.EmitReadArrayElement(index, ir_builder); + }; + ElementGenerator update_array_generator = [&](const IrArray::Index& index) { + return update_array.EmitReadArrayElement(index, ir_builder); + }; + + return EmitDynamicUpdateSliceInPlaceImpl( + update_shape, start_indices_generator, update_array_generator, + output_array, /*launch_dimensions=*/nullptr, name, ir_builder); +} + +// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and +// EmitParallelFusedDynamicUpdateSliceInPlace. +// +// Emits a sequential loop if launch_dimensions is null. +static Status EmitFusedDynamicUpdateSliceInPlaceImpl( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions* launch_dimensions, + llvm::IRBuilder<>* ir_builder) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " + << fusion->ToShortString(); + + auto* dynamic_update_slice = fusion->fused_expression_root(); + + const auto* update = dynamic_update_slice->operand(1); + const auto* start_indices = dynamic_update_slice->operand(2); + Shape update_shape = update->shape(); + + // Our in-place dynamic-update-slice implementation emits a loop over + // update_shape. To emit a cache-friendly loop, we need to know that shape's + // layout. + // + // update_shape is inside a fusion node -- it's never materialized in memory + // and thus doesn't have a layout. In this case we use the layout of the + // fusion node for iteration, since that corresponds to the order in memory of + // the buffer we'll be writing to. + // + // (This isn't necessarily optimal; in some cases it might be faster to peek + // through the chain of ops that gives us the update operand and use the + // layout of its source buffer(s). But this is no worse than we do with + // fusion elsewhere.) + TF_RETURN_IF_ERROR( + LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); + + // Create element generators for update and start_indices. + FusedIrEmitter fused_emitter(fusion_operand_arrays, elemental_emitter); + TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); + ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); + ElementGenerator start_indices_generator = + fused_emitter.GetGenerator(start_indices); + + return EmitDynamicUpdateSliceInPlaceImpl( + update_shape, start_indices_generator, update_array_generator, + fusion_output_array, launch_dimensions, IrName(fusion), ir_builder); +} + +Status EmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + llvm::IRBuilder<>* ir_builder) { + return EmitFusedDynamicUpdateSliceInPlaceImpl( + fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + /*launch_dimensions=*/nullptr, ir_builder); } -llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, - int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder) { - llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( - operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); - llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); - SetTbaaForInstruction(src_buffer, target_shape, /*is_pointer_to=*/true); - SetAlignmentMetadataForLoad(src_buffer, alignment); - llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder); - llvm::Value* ret_val = - ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); - return ret_val; +Status EmitParallelFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* ir_builder) { + return EmitFusedDynamicUpdateSliceInPlaceImpl( + fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + &launch_dimensions, ir_builder); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index 4e1d9d1080b3a5c8d8a09145f68bcff9d329c929..11e84d9cb5defbcb87a8f696d56c139686c960d8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -13,67 +13,68 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/types.h" + +// Utilities related to emitting LLVM IR for various HLO ops. namespace xla { namespace llvm_ir { -// Selection among tuples is special in how it's lowered, because a tuple is not -// an HLO array. -// -// tuple_on_true tuple_on_false -// | | -// V V -// ------------------------ ------------------------ -// | address of element 0 | | address of element 0 | -// |----------------------| |----------------------| -// | address of element 1 | | address of element 1 | -// |----------------------| |----------------------| -// | address of element 2 | | address of element 2 | -// ------------------------ ------------------------ -// \ / -// \ / -// ---------- -// pred ---------> | select | -// ---------- -// | -// V -// output ----> ------------------------ -// | address of element 0 | -// |----------------------| -// | address of element 1 | -// |----------------------| -// | address of element 2 | -// ------------------------ +// Checks if we can emit code for the given DynamicUpdateSlice node that updates +// its input in place. Returns true if the dynamic-update-slice's +// array-to-be-updated and output share the same BufferAllocation::Slice. // -// Only the addresses are copied to the output. For each element, we emit a copy -// of the address from the corresponding element in either -// tuple_on_true or tuple_on_false: -// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder); +// dynamic_update_slice must be a DynamicUpdateSlice op. +bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, + const BufferAssignment& assignment); + +// Checks if the given fusion node is amenable to being implemented by +// EmitFusedDynamicUpdateSliceInPlace. +inline bool CanEmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, const BufferAssignment& assignment) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && + fusion->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice && + CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(), + assignment); +} + +// Emits IR for running the given dynamic-update-slice op in-place -- that is, +// where the input and output buffers share the same slice, so we can simply +// modify the input/output buffer without touching any of the other elements. +Status EmitDynamicUpdateSliceInPlace( + tensorflow::gtl::ArraySlice operand_arrays, + const IrArray& output_array, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder); + +// Given a loop-fusion node whose root is a dynamic-update-slice op whose +// array-to-be-updated and output share the same buffer slice, emits +// (sequential) code for a fusion node that does the dynamic-update-slice in +// place. +Status EmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + llvm::IRBuilder<>* ir_builder); -// A tuple is an array of pointers, one for each operand. Each pointer points to -// the output buffer of its corresponding operand. -void EmitTuple(IrArray tuple, - tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder); +// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with +// the given launch dimensions. +Status EmitParallelFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* ir_builder); -// A tuple is an array of pointers, one for each operand. Each pointer points to -// the output buffer of its corresponding operand. A GetTupleElement instruction -// forwards the pointer to underlying tuple element buffer at the given index. -// Returns an llvm value representing a pointer to the tuple element buffer. -llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, - int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder); } // namespace llvm_ir } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..6051cbfc6f6a2d3cc99740beda5dee03a9392bdd --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" + +#include +#include +#include + +#include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace llvm_ir { + +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) { + CHECK(ShapeUtil::IsScalar(pred.GetShape())); + + llvm::LoadInst* pred_value = + ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder->CreateICmpNE( + pred_value, + llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0), + "boolean_predicate"); + + VLOG(2) << "HandleSelect for tuple:"; + VLOG(2) << " pred_value: " << DumpToString(*pred_value); + VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); + + for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { + std::vector element_index = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; + llvm::Value* on_true_element_address = + ir_builder->CreateInBoundsGEP(on_true, element_index); + llvm::Value* on_true_element = ir_builder->CreateLoad( + on_true_element_address, + tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + llvm::Value* on_false_element_address = + ir_builder->CreateInBoundsGEP(on_false, element_index); + llvm::Value* on_false_element = ir_builder->CreateLoad( + on_false_element_address, + tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + + llvm::Value* output_element_address = + ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); + ir_builder->CreateStore( + ir_builder->CreateSelect( + pred_cond, on_true_element, on_false_element, + tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + output_element_address); + } +} + +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder) { + for (size_t i = 0; i < operands.size(); ++i) { + auto* store = ir_builder->CreateStore( + ir_builder->CreatePointerCast(operands[i], + PrimitiveTypeToIrType(TUPLE, ir_builder)), + ir_builder->CreateInBoundsGEP( + tuple.GetBasePointer(), + {ir_builder->getInt64(0), ir_builder->getInt64(i)})); + tuple.AnnotateLoadStoreInstructionWithMetadata(store); + } +} + +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder) { + llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( + operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); + llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); + + // Mark the loaded pointer as dereferenceable if we know its shape. + if (!ShapeUtil::IsOpaque(target_shape)) { + SetDereferenceableMetadataForLoad( + src_buffer, + ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); + } + SetAlignmentMetadataForLoad(src_buffer, alignment); + + llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder); + llvm::Value* ret_val = + ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); + return ret_val; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a75cdc815808fc3b9e8669dde1eddf995080f53d --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +// Utilities for emitting LLVM IR related to HLO tuples. + +namespace xla { +namespace llvm_ir { + +// Selection among tuples is special in how it's lowered, because a tuple is not +// an HLO array. +// +// tuple_on_true tuple_on_false +// | | +// V V +// ------------------------ ------------------------ +// | address of element 0 | | address of element 0 | +// |----------------------| |----------------------| +// | address of element 1 | | address of element 1 | +// |----------------------| |----------------------| +// | address of element 2 | | address of element 2 | +// ------------------------ ------------------------ +// \ / +// \ / +// ---------- +// pred ---------> | select | +// ---------- +// | +// V +// output ----> ------------------------ +// | address of element 0 | +// |----------------------| +// | address of element 1 | +// |----------------------| +// | address of element 2 | +// ------------------------ +// +// Only the addresses are copied to the output. For each element, we emit a copy +// of the address from the corresponding element in either +// tuple_on_true or tuple_on_false: +// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. A GetTupleElement instruction +// forwards the pointer to underlying tuple element buffer at the given index. +// Returns an llvm value representing a pointer to the tuple element buffer. +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 3235081f83f53e2850efa2c6ccd221318fa0c58b..d4d35da9d636e6e204f36850e7987327ab258696 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -91,7 +91,7 @@ int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal, bool has_hybrid_result) { + const Shape* result_layout, int device_ordinal) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(computation)); VersionedComputationHandle versioned_handle = @@ -133,8 +133,7 @@ StatusOr> LocalService::CompileExecutable( } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options, - has_hybrid_result)); + CreateModuleConfig(*program_shape, argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, execute_backend_->stream_executor(device_ordinal)); diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index f2bfb960f4307d12556337f76cfd6ea7a38b6e20..52c4346385eb663baa6e7579d7b3883ba084205b 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -45,7 +45,7 @@ class LocalService : public Service { StatusOr> CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal, bool has_hybrid_result); + const Shape* result_layout, int device_ordinal); private: explicit LocalService(const ServiceOptions& options, diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 069f85af721228c8f5d40cf243eea7f1e5173c62..a0d08c288dbcc45e83a36ce7b094b04a9dbae532 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -23,7 +23,24 @@ namespace xla { string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); - int* count = &(generated_names_[root]); + + // Strip away numeric suffix (if any). Only recognize separator if it is in + // the middle of the name. + size_t separator_index = root.rfind(separator_); + if (separator_index != string::npos && (separator_index > 0) && + (separator_index < root.size() - 1)) { + string after_suffix = root.substr(separator_index + 1); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + // Remove numeric suffix from root. + root = root.substr(0, separator_index); + // Update count to at least the numeric suffix value to avoid future + // colisions with this name. + generated_names_[root] = std::max(generated_names_[root], numeric_suffix); + } + } + + int64* count = &(generated_names_[root]); if (*count == 0) { *count = 1; return root; @@ -31,9 +48,6 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { tensorflow::strings::StrAppend(&root, separator_, *count); // Increment lookup under old 'root' name. (*count)++; - // Initialize count under new 'root' name. - count = &(generated_names_[root]); - *count = 1; return root; } } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index b0944adbc1d98fd88c550cc8b53cf399e43535e6..ed379b52258463b960dea788721c2c4325ef0260 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -43,7 +43,7 @@ class NameUniquer { // Map from name prefix to the number of names generated using that prefix // so far. - std::unordered_map generated_names_; + std::unordered_map generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9f0747a6e2175a968d8f3661ac51512009e86f29 --- /dev/null +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/name_uniquer.h" + +#include +#include +#include + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class NameUniquerTest : public ::testing::Test {}; + +TEST_F(NameUniquerTest, SimpleUniquer) { + NameUniquer uniquer; + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo__1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo__2", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar")); + EXPECT_EQ("foo__3", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar__1", uniquer.GetUniqueName("bar")); + EXPECT_EQ("qux", uniquer.GetUniqueName("qux")); +} + +TEST_F(NameUniquerTest, DifferentSeparator) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.2", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar")); + EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar")); +} + +TEST_F(NameUniquerTest, NumericSuffixes) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); + EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); + + // Separator is only recognized in the middle of the prefix. + EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); + EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 404fd3e6d7faaedd6cfba2996034f72d8f116d4e..0fb90230f2f39a841973361f63d17af579a1342b 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -48,23 +48,28 @@ namespace xla { namespace { -// Checks if an instruction can change its shape simply by adjusting metadata. -// This is the case if it is: -// -// - an instruction does not have any producers like Constants -// or Rng instruction, or is a scalar. -// -// Or -// -// - an reshape/transpose instruction with an operand that can trivially change -// its shape. -bool InstructionCanTriviallyChangeShape(const HloInstruction* instruction) { - // Reshape/Transposes are only trivial if their operand is trivial. - if (instruction->opcode() == HloOpcode::kReshape || - instruction->opcode() == HloOpcode::kTranspose) { - CHECK_EQ(instruction->operand_count(), 1); - return InstructionCanTriviallyChangeShape(instruction->operand(0)); - } +bool IsReshapeOrTranspose(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kReshape || + instruction->opcode() == HloOpcode::kTranspose; +} + +// Returns true iff `instruction` can change its shape simply by adjusting +// metadata. +bool CanTriviallyChangeShape(const HloInstruction* instruction) { + // NOTE: Technically a sequence of reshape(reshape(constant)) is also + // trivially reshapable, so we might be tempted to simply recurse if + // IsReshapeOrTranspose(instruction)==true. + // + // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially + // reshapable if *all* instructions in the chain have user_count == 1. And + // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar; we + // rely on implicit scalar broadcast for scalars to be trivial. In addition, + // these cases make it harder to maintain correctness of the UpdateOperand + // logic below. + // + // So don't handle these chains, unless you update the tests and code to deal + // with these properly. One idea is to add a pass immediately beforehand that + // collapses trivial runs of reshapes / transposes. // Scalars can operate with any shape. if (ShapeUtil::IsScalar(instruction->shape())) { @@ -93,9 +98,8 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( const HloInstruction* hlo) { for (HloInstruction* operand : hlo->operands()) { if (!ShapeUtil::IsScalar(operand->shape()) && - ((operand->opcode() == HloOpcode::kReshape || - operand->opcode() == HloOpcode::kTranspose) && - !InstructionCanTriviallyChangeShape(operand->operand(0)))) { + IsReshapeOrTranspose(operand) && + !CanTriviallyChangeShape(operand->operand(0))) { VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " << hlo->ToStringNoMetadata() << ":\n\t" << operand->ToStringNoMetadata(); @@ -122,28 +126,15 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } -// Returns true if an elementwise operation has all operands that can easily -// change shape. Operands can easily change shape if they are all -// reshapes/transposes to and from the same shape. Additionally, operands like -// constant, rng, and any scalar change shape with only an adjustment of -// metadata. -bool IsElementwiseOfEquivalentReshapesOrTransposes( - const HloInstruction* instruction) { - const auto& operands = instruction->operands(); - HloInstruction* first_reshape_operand = - FirstNonScalarAndNonTrivialReshapeOperand(instruction); - // If there are no non-trivial reshapes or transposes, then there is nothing - // to sink below the elementwise operation. - if (!first_reshape_operand) { - return false; - } - VLOG(3) << "** Checking whether instruction is an elementwise operation of " - "equivalent reshapes/transposes: " +// Returns true if all operands of `instruction` can easily change shape. +// Operands can easily change shape if they are all reshapes/transposes to and +// from the same shape. Additionally, operands like constant, rng, and any +// scalar change shape with only an adjustment of metadata. +bool AllOperandsHaveEasyShapeChanges( + const HloInstruction* instruction, + const HloInstruction* first_reshape_operand) { + VLOG(3) << "** Checking whether all operands have easy shape changes: " << instruction->ToStringNoMetadata(); - bool result = (instruction->user_count() > 0 || - instruction == instruction->parent()->root_instruction()) && - instruction->IsElementwise() && !operands.empty(); - // Check whether all operands: // 0. Have the same dimensions as the output -- if not, it may be // implicitly broadcast, which can confound the movement's @@ -155,66 +146,117 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes( // or // 2. Are one of kConstant, kRng, and scalars that can change shape // trivially, - if (result) { - for (auto& operand : operands) { - if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToStringNoMetadata() - << "\n\tinstruction: " << instruction->ToStringNoMetadata(); - result = false; - break; - } - - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToStringNoMetadata() - << "\n\toperand: " << operand->ToStringNoMetadata(); - continue; - } + for (const HloInstruction* operand : instruction->operands()) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToStringNoMetadata() + << "\n\tinstruction: " << instruction->ToStringNoMetadata(); + return false; + } - if (InstructionCanTriviallyChangeShape(operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToStringNoMetadata(); - continue; - } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\toperand: " << operand->ToStringNoMetadata(); + continue; + } - // TODO(someone): Look into supporting general ops for the operands as - // well. - VLOG(5) << "Operand is neither equalivant to the first Reshape operand" - "nor can trivially change shape: " + if (CanTriviallyChangeShape(operand)) { + VLOG(5) << "Operand can trivially change shape: " << operand->ToStringNoMetadata(); - result = false; - break; + continue; } + + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is neither equalivant to the first Reshape operand" + "nor can trivially change shape: " + << operand->ToStringNoMetadata(); + return false; } - VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for " - << instruction->ToStringNoMetadata() << ": " << result; - return result; + VLOG(3) << "All operands have easy shape changes: " + << instruction->ToStringNoMetadata(); + return true; +} + +// This function is called once we've decided to sink reshape/transpose operands +// across an instruction. It returns an updated `operand` with a shape that +// plays nicely with `new_operand_shape`; either it has the same shape (of the +// correct type), or it is a scalar that may be implicitly broadcast. +HloInstruction* UpdateOperand(HloComputation* computation, + const HloInstruction* first_reshape_operand, + const Shape& new_operand_shape, + HloInstruction* operand) { + const PrimitiveType element_type = operand->shape().element_type(); + const Shape new_shape = + ShapeUtil::ChangeElementType(new_operand_shape, element_type); + + switch (operand->opcode()) { + case HloOpcode::kConstant: { + if (first_reshape_operand->opcode() == HloOpcode::kReshape) { + VLOG(5) << "Adding reshape to kConstant operand"; + return computation->AddInstruction( + HloInstruction::CreateReshape(new_shape, operand)); + } else { + CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); + VLOG(5) << "Adding transpose to kConstant operand"; + std::vector inverse_permutation = + InversePermutation(first_reshape_operand->dimensions()); + return computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); + } + } + case HloOpcode::kRng: { + CHECK_EQ(operand->user_count(), 1); + VLOG(5) << "Cloning kRng operand with new shape"; + return computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + } + case HloOpcode::kReshape: + case HloOpcode::kTranspose: { + VLOG(5) << "Using existing operand of kReshape or kTranspose"; + return operand->mutable_operand(0); + } + default: + LOG(FATAL) << "Unexpected operand opcode during update: " << operand; + } } // Try to sink any reshape or transpose operands of `instruction` across it. We // do so if `instruction` is elementwise and all operands are either equivalent -// reshapes/transposes or are trivially reshapable. Note that no move is -// performend if there is no nontrivial reshapes/transposes. +// reshapes/transposes or are trivially reshapable. StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, HloInstruction* instruction) { - if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { + // Only perform sinks for live elementwise instructions with operands. + const bool is_dead = instruction->user_count() == 0 && + instruction != computation->root_instruction(); + if (!instruction->IsElementwise() || instruction->operands().empty() || + is_dead) { return false; } - HloInstruction* old_reshape = + // Only perform sinks if there are any nontrivial reshape/transpose operands. + const HloInstruction* first_reshape_operand = FirstNonScalarAndNonTrivialReshapeOperand(instruction); - TF_RET_CHECK(old_reshape != nullptr); - Shape new_elementwise_shape = old_reshape->operand(0)->shape(); + if (!first_reshape_operand) { + return false; + } + + // Only perform sinks if all operands can easily change shape. + if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { + return false; + } - VLOG(3) << "** Trying to sink reshape or transpose: " - << instruction->ToStringNoMetadata() - << "\n\told reshape: " << old_reshape->ToStringNoMetadata() - << "\n\tnew elementwise shape: " - << ShapeUtil::HumanString(new_elementwise_shape); + // At this point we've decided to sink reshape/transpose operands. + const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); + VLOG(3) << "** Sinking reshape or transpose: " + << instruction->ToStringNoMetadata() << "\n\tfirst reshape operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\tnew operand shape: " + << ShapeUtil::HumanString(new_operand_shape); auto operands = instruction->operands(); for (size_t i = 0; i < operands.size(); ++i) { @@ -224,55 +266,19 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, if (ShapeUtil::IsScalar(operands[i]->shape())) { continue; } - PrimitiveType element_type = operands[i]->shape().element_type(); - switch (operands[i]->opcode()) { - case HloOpcode::kConstant: { - if (old_reshape->opcode() == HloOpcode::kReshape) { - VLOG(3) << "Creating reshape for kConstant operand " << i << ": " - << operands[i]->ToStringNoMetadata(); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i])); - } else { - TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose); - std::vector inverse_permutation = - InversePermutation(old_reshape->dimensions()); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateTranspose( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i], inverse_permutation)); - } - break; - } - case HloOpcode::kRng: { - CHECK_EQ(operands[i]->user_count(), 1); - operands[i] = instruction->parent()->AddInstruction( - operands[i]->CloneWithNewOperands( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i]->operands())); - break; - } - case HloOpcode::kReshape: - case HloOpcode::kTranspose: - operands[i] = operands[i]->mutable_operand(0); - break; - default: - LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " - "transposes."; - } + VLOG(3) << "Updating operand #" << i << ": " + << operands[i]->ToStringNoMetadata(); + operands[i] = UpdateOperand(computation, first_reshape_operand, + new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not be equivalent - // reshapes, so all the fused instructions have the same dimensions. + // implicit broadcast as if it were the operands would not have easy shape + // changes, so all the fused instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); - *shape->mutable_layout() = new_elementwise_shape.layout(); + *shape->mutable_dimensions() = new_operand_shape.dimensions(); + *shape->mutable_layout() = new_operand_shape.layout(); } } HloInstruction* new_elementwise = @@ -284,12 +290,12 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, // // In this case, convert' should have the same element type as // `convert` and the same dimensions as operands[0]. - ShapeUtil::ChangeElementType(new_elementwise_shape, + ShapeUtil::ChangeElementType(new_operand_shape, instruction->shape().element_type()), operands)); std::unique_ptr new_reshape; - switch (old_reshape->opcode()) { + switch (first_reshape_operand->opcode()) { case HloOpcode::kReshape: VLOG(3) << "Creating new reshape for new elementwise op: " << new_elementwise->ToStringNoMetadata(); @@ -297,8 +303,9 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, HloInstruction::CreateReshape(instruction->shape(), new_elementwise); break; case HloOpcode::kTranspose: - new_reshape = HloInstruction::CreateTranspose( - instruction->shape(), new_elementwise, old_reshape->dimensions()); + new_reshape = + HloInstruction::CreateTranspose(instruction->shape(), new_elementwise, + first_reshape_operand->dimensions()); break; default: LOG(FATAL) << "Bad opcode"; @@ -312,6 +319,8 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, StatusOr ReshapeMover::Run(HloModule* module) { bool changed = false; + VLOG(2) << "Pre ReshapeMover HLO:"; + XLA_VLOG_LINES(2, module->ToString()); for (auto* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool did_change, @@ -319,6 +328,8 @@ StatusOr ReshapeMover::Run(HloModule* module) { changed |= did_change; } } + VLOG(2) << "Post ReshapeMover HLO:"; + XLA_VLOG_LINES(2, module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index b7e0a46939a10b3376758109214c9722976f50e0..1f59e3b3147facb6f2fae00d6c810bf54d560e5c 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -26,7 +26,7 @@ namespace xla { // them inputward also. class ReshapeMover : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "reshape-motion"; } + tensorflow::StringPiece name() const override { return "reshape-mover"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index a81d3f4eb344510b3973aa46a576d11f613bc404..aac8638a54f744f0c230ec6c5ca071c1daf45ab2 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloTestBase; +using ReshapeMoverTest = HloVerifiedTestBase; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); @@ -50,13 +50,12 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); @@ -89,13 +88,12 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); @@ -115,13 +113,12 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -142,12 +139,11 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, param1))); @@ -193,21 +189,19 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2)); builder.AddInstruction(HloInstruction::CreateTernary( - ShapeUtil::MakeShape(PRED, {2, 3}), HloOpcode::kSelect, const0, reshape1, - reshape2)); + root_shape, HloOpcode::kSelect, const0, reshape1, reshape2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(const0, reshape1, reshape2)); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(op::Reshape(const0), param1, param2))); - EXPECT_EQ(const0->shape().DebugString(), + EXPECT_EQ(root_shape.DebugString(), computation->root_instruction()->shape().DebugString()); } @@ -228,17 +222,16 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, root_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); @@ -260,7 +253,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // trivial reshapes. TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); - auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); + auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = @@ -272,18 +265,17 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); auto pred = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(PRED, {1, 3, 1, 2}), "pred")); + 0, ShapeUtil::MakeShape(PRED, {3, 2}), "pred")); builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); @@ -323,13 +315,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), const1)); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, op::Reshape(const1)))); @@ -337,6 +328,48 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { computation->root_instruction()->shape().DebugString()); } +// For a graph that looks like: +// +// +- reshape0 - param0 (shape A) +// | +// +- reshape1 - const1 (shape B) +// | +// add +// +// There is 1 non-trivial reshape (reshape0). It's not clear whether reshape1 +// should be trivial or not; conceptually it's trivial, but handling it would +// complicate the rest of our logic. +// +// For now we treat it as non-trivial, so we verify that we don't sink the +// reshapes in this case. +TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 3}), "param0")); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({9, 8, 7}))); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); + + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, reshape1)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(const1))); + + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(const1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); @@ -351,15 +384,14 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, param1))); @@ -386,14 +418,13 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(pred, param0, param1))); @@ -416,12 +447,11 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); @@ -468,12 +498,11 @@ TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kMultiply, constant, reshape)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Constant(), op::Reshape(param0))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Constant(), op::Reshape(param0))); @@ -517,15 +546,14 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, reshape2, reshape3)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Add(op::Reshape(param2), op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 049ae91e9308dc8ab89db0328cf8098ca54ef098..0fbc2f2fec64917f5117dc5021c5e0a5b0f4367e 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -153,7 +153,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { Service::Service(const ServiceOptions& options, std::unique_ptr execute_backend) : options_(options), execute_backend_(std::move(execute_backend)) { - CHECK(options_.number_of_replicas() > 0); + CHECK_GT(options_.number_of_replicas(), 0); if (execute_backend_) { if (execute_backend_->device_count() > 0) { CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) @@ -187,8 +187,9 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg, *result->mutable_computation() = computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p", - result->computation().ShortDebugString().c_str(), this); + VLOG(1) << Printf("Created new computation %s on service %p, name %s", + result->computation().ShortDebugString().c_str(), this, + arg->name().c_str()); return tensorflow::Status::OK(); } @@ -268,7 +269,7 @@ StatusOr> Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, bool has_hybrid_result) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -305,7 +306,6 @@ StatusOr> Service::CreateModuleConfig( } config->set_replica_count(options_.number_of_replicas()); - config->set_has_hybrid_result(has_hybrid_result); if (execution_options != nullptr) { config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); @@ -338,7 +338,7 @@ StatusOr>> Service::BuildExecutables( std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector executors) { + std::vector> executors) { VLOG(1) << Printf("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. @@ -615,31 +615,41 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); std::vector> all_arguments; - std::vector executors; + std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; std::vector computation_names; std::vector device_handles; - if (arg->requests_size() * options_.number_of_replicas() > + int num_requested_devices = + std::accumulate(arg->requests().begin(), arg->requests().end(), 0, + [](int a, const ExecuteRequest& r) -> int { + return a + r.execution_options().device_handles_size(); + }); + if (num_requested_devices * options_.number_of_replicas() > execute_backend_->device_count()) { return FailedPrecondition( "there are not enough stream executors to execute %d computations", - arg->requests_size()); + num_requested_devices); } for (int64 i = 0; i < arg->requests_size(); ++i) { // Get the stream executor for the i'th computation. This stream executor // is one of the executors to run the replicated computation. - if (!arg->requests(i).has_device_handle()) { + const ExecutionOptions& execution_options = + arg->requests(i).execution_options(); + if (execution_options.device_handles().empty()) { return FailedPrecondition( "device handles must be given to execute parallel computations"); } - TF_ASSIGN_OR_RETURN( - auto replicas, - Replicas(*execute_backend_, arg->requests(i).device_handle())); - se::StreamExecutor* executor = replicas[0]; - CHECK(executor != nullptr); + std::vector executors; + for (const auto& device_handle : execution_options.device_handles()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, device_handle)); + se::StreamExecutor* executor = replicas[0]; + CHECK(executor != nullptr); + executors.push_back(executor); + } // Resolve the UserComputation object associated with the requested // computation and compute the program shape. @@ -658,10 +668,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. + // In the case of partitioned computations, assume all arguments go on the + // zeroth core. TF_ASSIGN_OR_RETURN( std::vector arg_allocations, ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), - executor->device_ordinal())); + executors[0]->device_ordinal())); std::vector arguments; arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { @@ -678,11 +690,15 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {}); versioned_handles.push_back(versioned_handle); module_configs.push_back(std::move(module_config)); - computation_names.push_back(user_computation->name()); - executors.push_back(executor); - device_handles.push_back(arg->requests(i).device_handle()); + computation_names.insert(computation_names.end(), executors.size(), + user_computation->name()); + all_executors.push_back(executors); + device_handles.insert(device_handles.end(), + execution_options.device_handles().begin(), + execution_options.device_handles().end()); } // Build the user computations into HloModules and compile to generate the @@ -690,7 +706,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::vector> executables, BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), executors)); + execute_backend_.get(), all_executors)); std::vector executable_ptrs; executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { @@ -752,6 +768,17 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return InvalidArgument("computations may not be empty"); } + // If we received multiple device handles, we must partition the module. + if (arg->execution_options().device_handles_size() > 1) { + ExecuteParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index bb86a53c62e05bb62b93bbac88c2ca251ad0439a..2452259f736054b5bf1f03fc5103d65eded7f398 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -277,8 +277,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - bool has_hybrid_result = false); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. StatusOr> BuildExecutable( @@ -294,7 +293,7 @@ class Service : public ServiceInterface { std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector executors); + std::vector> executors); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ffd80188274e17b4345c1d9d4f47ea04195798b2..6be6b77e85ba3b758e062f3c9588d3ed6276bd30 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -57,8 +57,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_IS_FINITE; case HloOpcode::kLog: return UNOP_LOG; - case HloOpcode::kLogicalNot: - return UNOP_LOGICAL_NOT; + case HloOpcode::kNot: + return UNOP_NOT; case HloOpcode::kNegate: return UNOP_NEGATE; case HloOpcode::kRoundNearestAfz: @@ -113,10 +113,16 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_POW; case HloOpcode::kRemainder: return BINOP_REM; - case HloOpcode::kLogicalOr: - return BINOP_LOGICAL_OR; - case HloOpcode::kLogicalAnd: - return BINOP_LOGICAL_AND; + case HloOpcode::kOr: + return BINOP_OR; + case HloOpcode::kAnd: + return BINOP_AND; + case HloOpcode::kShiftLeft: + return BINOP_SHIFT_LEFT; + case HloOpcode::kShiftRightArithmetic: + return BINOP_SHIFT_RIGHT_ARITHMETIC; + case HloOpcode::kShiftRightLogical: + return BINOP_SHIFT_RIGHT_LOGICAL; default: LOG(FATAL) << "unhandled opcode " << opcode; } @@ -322,11 +328,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_SORT: return arg; - case UNOP_LOGICAL_NOT: - if (arg.element_type() != PRED) { + case UNOP_NOT: + if (arg.element_type() != PRED && + !primitive_util::IsIntegralType(arg.element_type())) { return InvalidArgument( - "expected pred element type in argument to logical-not operation; " - "got %s", + "expected pred or an integral element type in argument to not " + "operation; got %s", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -457,7 +464,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { return InvalidArgument( - "the rank of the operand and the padding configuration do not match."); + "The rank of the operand and the padding configuration do not match: " + "%s vs %s", + ShapeUtil::HumanString(operand_shape).c_str(), + padding_config.ShortDebugString().c_str()); } if (operand_shape.element_type() != padding_value_shape.element_type()) { return InvalidArgument( @@ -747,20 +757,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_DIV: case BINOP_REM: case BINOP_MUL: + case BINOP_SHIFT_LEFT: + case BINOP_SHIFT_RIGHT_ARITHMETIC: + case BINOP_SHIFT_RIGHT_LOGICAL: return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); - case BINOP_LOGICAL_AND: - case BINOP_LOGICAL_OR: - if (lhs.element_type() != PRED) { + case BINOP_AND: + case BINOP_OR: + if (lhs.element_type() != PRED && + !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( - "expected pred element type in argument to logical and/or " - "operation; got %s", + "expected pred or integral type in argument to and/or operation; " + "got %s", PrimitiveType_Name(lhs.element_type()).c_str()); } return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: case BINOP_GE: case BINOP_GT: @@ -1406,8 +1419,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Verifies that the input and window dimensions are a permutation of // the dimension numbers. std::vector input_dnums(num_dims); - input_dnums[0] = dnums.batch_dimension(); - input_dnums[1] = dnums.feature_dimension(); + input_dnums[0] = dnums.input_batch_dimension(); + input_dnums[1] = dnums.input_feature_dimension(); std::copy(dnums.spatial_dimensions().begin(), dnums.spatial_dimensions().end(), input_dnums.begin() + 2); std::sort(input_dnums.begin(), input_dnums.end()); @@ -1447,8 +1460,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int i = 0; i < num_spatial_dims; ++i) { input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i)); } - const int64 input_features = lhs.dimensions(dnums.feature_dimension()); - const int64 input_batch = lhs.dimensions(dnums.batch_dimension()); + const int64 input_features = lhs.dimensions(dnums.input_feature_dimension()); + const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension()); std::vector kernel_spatial_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1490,8 +1503,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /*allow_negative_padding=*/true)); std::vector dimensions(num_dims); - dimensions[dnums.batch_dimension()] = input_batch; - dimensions[dnums.feature_dimension()] = kernel_output_features; + dimensions[dnums.output_batch_dimension()] = input_batch; + dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i); } @@ -1894,11 +1907,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); + VLOG(3) << "Reshape inferred shape: " + << ShapeUtil::HumanString(inferred_shape); if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "reshape operation has mismatched element counts: from=%lld to=%lld", - ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape)); + "reshape operation has mismatched element counts: from=%lld (%s) " + "to=%lld (%s)", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + ShapeUtil::ElementsIn(inferred_shape), + ShapeUtil::HumanString(inferred_shape).c_str()); } std::vector indices(ShapeUtil::Rank(operand)); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7c9c7e8d6ac32d868ad18d892b3f8e9dc18b8259..8df4a73229df25043d5490b0336b65955d4f4eed 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -352,8 +352,10 @@ TEST_F(ShapeInferenceTest, Convolve) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -392,8 +394,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -433,8 +437,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -475,8 +481,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(3); - dnums.set_feature_dimension(2); + dnums.set_input_batch_dimension(3); + dnums.set_output_batch_dimension(3); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); dnums.add_spatial_dimensions(0); dnums.add_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 3f62501bb5c46f397b5cd96688e50a43f8e83428..b3506b72bf5ab1aa27704c18c8a1dc69881caf71 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -58,8 +58,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kIsFinite; case UNOP_LOG: return HloOpcode::kLog; - case UNOP_LOGICAL_NOT: - return HloOpcode::kLogicalNot; + case UNOP_NOT: + return HloOpcode::kNot; case UNOP_NEGATE: return HloOpcode::kNegate; case UNOP_ROUND_NEAREST_AFZ: @@ -111,10 +111,16 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kPower; case BINOP_REM: return HloOpcode::kRemainder; - case BINOP_LOGICAL_OR: - return HloOpcode::kLogicalOr; - case BINOP_LOGICAL_AND: - return HloOpcode::kLogicalAnd; + case BINOP_OR: + return HloOpcode::kOr; + case BINOP_AND: + return HloOpcode::kAnd; + case BINOP_SHIFT_LEFT: + return HloOpcode::kShiftLeft; + case BINOP_SHIFT_RIGHT_ARITHMETIC: + return HloOpcode::kShiftRightArithmetic; + case BINOP_SHIFT_RIGHT_LOGICAL: + return HloOpcode::kShiftRightLogical; default: LOG(FATAL) << "unhandled operation " << binop; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e45b839afd2a9666215744f904dfbed5eca0a41b..b02d906d93e8854fc33fc49514f97e6a1b81b110 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -127,6 +127,22 @@ cc_library( ], ) +cc_library( + name = "hlo_verified_test_base", + testonly = True, + srcs = ["hlo_verified_test_base.cc"], + hdrs = ["hlo_verified_test_base.h"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -373,6 +389,7 @@ xla_test( name = "params_test", srcs = ["params_test.cc"], shard_count = 30, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1394,8 +1411,10 @@ xla_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "//third_party/eigen3", ], ) @@ -1461,6 +1480,7 @@ xla_test( xla_test( name = "local_client_execute_test", srcs = ["local_client_execute_test.cc"], + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 24bccf686349e60f2a4fc4b3f2f6ef836f07d107..a62b13e04ff35b06846039d7665dfc8e4205eec2 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -496,58 +496,315 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalAnd) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.LogicalAnd(a, b); + auto out = builder.And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, false}, {true, true}}); + auto b = builder.ConstantR2({{false, true}, {false, true}}); + auto out = builder.And(a, b); + + Array2D expected_array({{false, false}, {false, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.LogicalAnd(a, b); + auto out = builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalOr) { +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, -1, -8}); + auto b = builder.ConstantR1({5, -7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {0, -7, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); + auto b = builder.ConstantR2({{1, -6}, {4, 5}}); + auto out = builder.And(a, b); + + Array2D expected_array({{0, -6}, {4, 5}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 1, 8}); + auto b = builder.ConstantR1({5, 7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {0, 1, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 1}, {3, 8}}); + auto b = builder.ConstantR2({{1, 0}, {7, 6}}); + auto out = builder.And(a, b); + + Array2D expected_array({{0, 0}, {3, 0}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.LogicalOr(a, b); + auto out = builder.Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, false}, {true, true}}); + auto b = builder.ConstantR2({{false, true}, {false, true}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{false, true}, {true, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.LogicalOr(a, b); + auto out = builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalNot) { +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, -1, 8}); + auto b = builder.ConstantR1({5, -7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {5, -1, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, -1}, {8, 8}}); + auto b = builder.ConstantR2({{5, -7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{5, -1}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 1, 8}); + auto b = builder.ConstantR1({5, 7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {5, 7, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 1}, {8, 8}}); + auto b = builder.ConstantR2({{5, 7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{5, 7}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, true, true, false}); - auto out = builder.LogicalNot(a); + auto out = builder.Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, true}, {true, false}}); + auto out = builder.Not(a); + + Array2D expected_array({{true, false}, {false, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); - auto out = builder.LogicalNot(a); + auto out = builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({-1, 0, 1}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {0, -1, -2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); + auto out = builder.Not(a); + + Array2D expected_array({{0, -1}, {-2, -9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 4294967295}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {4294967295, 0}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); + auto out = builder.Not(a); + + Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x12345678), + static_cast(0xF0001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1( + &builder, + {static_cast(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1(&builder, + {static_cast(0xF9234567), + static_cast(0x00100010), 0, 0, 19}, + {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x12345678, 0xF0001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1( + &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 505fa059f28599c7c934749e06be8a7185c66cae..03f5e08315bfed2bcb43ebb7098aaa0b97228605 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -159,7 +159,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { } // Tests implicit broadcasting of PREDs. -XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { +XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { ComputationBuilder b(client_, TestName()); Array2D x_vals(2, 1); @@ -174,7 +174,7 @@ XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { ComputationDataHandle x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); - b.LogicalAnd(x, y, /*broadcast_dimensions=*/{1, 2}); + b.And(x, y, /*broadcast_dimensions=*/{1, 2}); Array3D expected(2, 2, 1); expected(0, 0, 0) = false; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 9f3b66e256dbb351b76a2e66912d3100495101be..a60d3e50bd4dc78ed8715f8d7814668b95f3d38a 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -40,7 +40,7 @@ namespace { Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { StatusOr result = ClientLibrary::GetOrCreateLocalClient(client_options); - TF_CHECK_OK(result.status()) << "could not create local client for testing"; + TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.ValueOrDie(); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 83882ca75e93ee9edec8e292991b53f1af57bb62..b0a63bccbb93f226175beff2e30e2a243fdca1d3 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -39,7 +39,8 @@ class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 2, 0, 2, 2, 3, 0, 1, 2, + 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -48,7 +49,8 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 0, 1, 2, 3, 2, 3, 2, + 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); @@ -73,14 +75,18 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, ConvolutionDimensionNumbers dim_nums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); // Swap batch_dimension and feature_dimension. - int64 tmp = dim_nums.batch_dimension(); - dim_nums.set_batch_dimension(dim_nums.feature_dimension()); - dim_nums.set_feature_dimension(tmp); + int64 old_input_batch_dim = dim_nums.input_batch_dimension(); + int64 old_output_batch_dim = dim_nums.output_batch_dimension(); + dim_nums.set_input_batch_dimension(dim_nums.input_feature_dimension()); + dim_nums.set_output_batch_dimension(dim_nums.output_feature_dimension()); + dim_nums.set_input_feature_dimension(old_input_batch_dim); + dim_nums.set_output_feature_dimension(old_output_batch_dim); // Swap kernel_input_feature_dimension and kernel_output_feature_dimension. - tmp = dim_nums.kernel_input_feature_dimension(); + int64 old_kernel_input_feature_dim = + dim_nums.kernel_input_feature_dimension(); dim_nums.set_kernel_input_feature_dimension( dim_nums.kernel_output_feature_dimension()); - dim_nums.set_kernel_output_feature_dimension(tmp); + dim_nums.set_kernel_output_feature_dimension(old_kernel_input_feature_dim); builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, dim_nums); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 7d06cce0c8f82e4a1c4fb847638613594257b80f..a7089c2897bee2a10b698df910b4805456257949 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -418,11 +418,13 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { // Tensorflow dimension numbers for 3D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); - dnums.set_feature_dimension(4); + dnums.set_input_feature_dimension(4); + dnums.set_output_feature_dimension(4); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.add_kernel_spatial_dimensions(2); @@ -469,10 +471,12 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { // Tensorflow dimension numbers for 2D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(2); @@ -520,9 +524,11 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) { // Tensorflow dimension numbers for 2D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); - dnums.set_feature_dimension(2); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); dnums.add_kernel_spatial_dimensions(0); dnums.set_kernel_input_feature_dimension(1); dnums.set_kernel_output_feature_dimension(2); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 145918db3e5e57c39054706d53bbfb7648af3143..9b36e3722b8f8a5d01c426425fdfb0c4b9ae3a16 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -974,10 +974,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1014,10 +1016,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1054,10 +1058,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1091,10 +1097,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 3bf9ccb19745b9e91d99614792dbec0443818f2b..a8f6488996087b57e3121ce2c7de918070950c72 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -17,8 +17,12 @@ limitations under the License. #include #include #include +#include #include +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" @@ -37,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -250,6 +255,42 @@ XLA_TEST_F(FusionTest, Parameter) { ErrorSpec(1e-4)); } +XLA_TEST_F(FusionTest, RandomizedParallelPartition) { + // Tests parallel partitioning of a fusion instruction. + // Create shape with random outer dimension size to generate random parallel + // partition counts for each test run. + const int seed = tensorflow::testing::RandomSeed(); + LOG(INFO) << "RandomizedParallelPartition seed: " << seed; + std::mt19937 generator(seed); + std::uniform_int_distribution distribution(128, 1024); + const int64 rand_dim0_size = distribution(generator); + const int64 dim1_size = 1024; + Shape shape = + ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); + // Build simple fusion computation: y = x^2 (elementwise). + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto x = + builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); + auto y = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x)); + + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two}, + HloInstruction::FusionKind::kLoop); + // Compute result. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + // Every element of result should be y = x^2 = 4.0. + for (int i = 0; i < rand_dim0_size; ++i) { + for (int j = 0; j < dim1_size; ++j) { + EXPECT_EQ(4.0, result->Get({i, j})); + } + } +} + XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); @@ -722,47 +763,104 @@ void BM_ParallelFusion(int num_iters) { auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); StreamExecutorMemoryAllocator allocator(platform, executors); - const int64 intra_op_parallelism_threads = 16; + const int64 intra_op_parallelism_threads = 24; xla::LocalClientOptions client_options; client_options.set_platform(platform); client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); auto client = ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); - const int64 dim_size = 1024; - // Create a simple fusable elementwise computation. + auto* transfer_manager = + TransferManager::GetForPlatform(platform).ValueOrDie(); + int device_ordinal = client->default_device_ordinal(); + + // Computation shape parameters. + const int64 param0_dim0 = 1024; + const int64 param0_dim1 = 1024; + const int64 param1_dim0 = 1024; + const int64 param1_dim1 = 1024; + const int64 param2_dim0 = 1024; + const int64 param2_dim1 = 1024; + + // Create computation. ComputationBuilder builder(client, "ParallelFusion"); - Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); - auto input0 = builder.Broadcast(builder.ConstantR0(1.5f), - AsInt64Slice(input_shape.dimensions())); - auto input1 = builder.Broadcast(builder.ConstantR0(2.0f), - AsInt64Slice(input_shape.dimensions())); - auto input2 = builder.Broadcast(builder.ConstantR0(3.0f), - AsInt64Slice(input_shape.dimensions())); - auto x = builder.Mul(input0, input1); - auto y = builder.Add(x, input2); + Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); + auto param0 = builder.Parameter(0, shape0, "param0"); + Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); + auto param1 = builder.Parameter(1, shape1, "param1"); + Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); + auto param2 = builder.Parameter(2, shape2, "param2"); + + auto x = builder.Mul(param0, param1); + auto y = builder.Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); + // Transfer literals to device. + auto buffer0 = + ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param0_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({}))); + + auto buffer1 = + ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param1_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({}))); + + auto buffer2 = + ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param2_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({}))); + + // Build executable. std::unique_ptr executable = - client->Compile(computation, {}, ExecutableBuildOptions()) + client + ->Compile(computation, + {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); - // Run some warm-up executions. + se::Stream stream(executors[client->default_device_ordinal()]); + stream.Init(); + + // Initialize thread pool. + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", + intra_op_parallelism_threads); + tensorflow::EigenThreadPoolWrapper tp(&pool); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + // Initialize ExecutableRunOptions. ExecutableRunOptions options; - options.set_allocator(&allocator); + options.set_allocator(&allocator).set_stream(&stream); + options.set_intra_op_thread_pool(&device); + + // Run some warm-up executions. const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. - tensorflow::testing::BytesProcessed(static_cast(num_iters) * dim_size * - dim_size * sizeof(float)); + const int64 total_bytes = param0_dim0 * param0_dim0 + + param1_dim0 * param1_dim0 + + param2_dim0 * param2_dim0; + tensorflow::testing::BytesProcessed(static_cast(num_iters) * + total_bytes * sizeof(float)); tensorflow::testing::UseRealTime(); tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc new file mode 100644 index 0000000000000000000000000000000000000000..31060b9e80fcd50aefdedca27c70ec8a9b8be743 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +/*static*/ int64 HloVerifiedTestBase::DefaultShapeSize(const Shape& shape) { + constexpr int64 kPointerSize = sizeof(void*); + if (ShapeUtil::IsOpaque(shape)) { + return kPointerSize; + } + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +HloVerifiedTestBase::HloVerifiedTestBase() : shape_size_fn_(DefaultShapeSize) {} + +HloVerifiedTestBase::~HloVerifiedTestBase() { + // We can't call the ASSERT or EXPECT test macros in destructors, so we + // perform HLO verification in TearDown, and use the CHECK here to ensure + // users don't accidentally override the verification. + CHECK(tear_down_called_) + << "TearDown was never called; subclasses of HloVerifiedTestBase that " + << "override TearDown must call the superclass TearDown."; +} + +void HloVerifiedTestBase::TearDown() { + EXPECT_FALSE(tear_down_called_) + << "TearDown called more than once; it should be called exactly once."; + tear_down_called_ = true; + if (module_) { + HloVerifier verifier(shape_size_fn_); + xla::StatusOr mutated = verifier.Run(module_.get()); + if (!mutated.ok()) { + ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); + } else { + EXPECT_FALSE(mutated.ValueOrDie()) + << "HloVerifier should never mutate the HloModule"; + } + } + HloTestBase::TearDown(); +} + +HloModule& HloVerifiedTestBase::module() { + if (!module_) { + module_ = CreateNewModule(); + } + return *module_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h new file mode 100644 index 0000000000000000000000000000000000000000..b3d6b5af3b46f932707abf309669d23c327d1334 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +// A base class for HLO tests that stores a default HloModule, and automatically +// performs verification on that module on tear-down. +class HloVerifiedTestBase : public HloTestBase { + public: + // Returns the size in bytes of the given shape, using a default pointer size. + static int64 DefaultShapeSize(const Shape& shape); + + protected: + HloVerifiedTestBase(); + ~HloVerifiedTestBase() override; + + // Performs verification on the default HloModule returned by module(). + // Automatically called by the testing framework for each test. + // + // REQUIRED: subclasses that override TearDown() must call this explicitly. + void TearDown() override; + + // Returns the default HloModule, lazily creating it if necessary via + // HloTestBase::CreateNewModule(). + HloModule& module(); + + // Sets the shape-size function used during hlo verification. If this isn't + // called, DefaultShapeSize is used instead. + void SetShapeSizeFn(std::function shape_size_fn) { + shape_size_fn_ = std::move(shape_size_fn); + } + + private: + std::unique_ptr module_; // Lazily populated. Access via module(). + std::function shape_size_fn_; + bool tear_down_called_ = false; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 4d8b50fbbf715e8d491667ecb4f4f336ef2d8a68..2876a79dd8b80f5ac1551df4184c853533fb59df 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -39,28 +39,60 @@ limitations under the License. namespace xla { -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual)); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return ::testing::AssertionFailure() + << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) + << " got: " << ShapeUtil::HumanString(actual); + } if (ShapeUtil::IsTuple(expected)) { - ASSERT_EQ(ShapeUtil::TupleElementCount(expected), - ShapeUtil::TupleElementCount(actual)); + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return ::testing::AssertionFailure() + << "want tuple element count: " + << ShapeUtil::TupleElementCount(expected) + << " got tuple element count: " + << ShapeUtil::TupleElementCount(actual); + } for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + ::testing::AssertionResult result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result) { + return result; + } } } else { - ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); - ASSERT_EQ(expected.element_type(), actual.element_type()) - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return ::testing::AssertionFailure() + << "want rank of: " << ShapeUtil::HumanString(expected) + << " got rank of: " << ShapeUtil::HumanString(actual); + } + if (expected.element_type() != actual.element_type()) { + return ::testing::AssertionFailure() + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return ::testing::AssertionFailure() + << "want dimensions_size " << expected.dimensions_size() + << " got dimensions_size " << actual.dimensions_size(); + } for (int i = 0; i < expected.dimensions_size(); ++i) { - ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); + if (expected.dimensions(i) != actual.dimensions(i)) { + return ::testing::AssertionFailure() + << "mismatch in dimension #" << i + << " expected: " << ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); + } } } + return ::testing::AssertionSuccess(); +} + +/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, + const Shape& actual) { + ASSERT_TRUE(EqualShapes(expected, actual)); } /* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( @@ -263,7 +295,14 @@ class NearComparator { VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + ::testing::AssertionResult equal_shapes = + LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); + if (!equal_shapes) { + EXPECT_TRUE(equal_shapes); + return false; + } // Set up members used during the comparison. num_miscompares_ = 0; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index f645c4e8dcda73806a4204876716b93aa5fb7185..467d44b857b74d2a38e9b3f8a32a9b1d39a4a10d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -50,6 +50,8 @@ class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. + static ::testing::AssertionResult EqualShapes(const Shape& expected, + const Shape& actual); static void AssertEqualShapes(const Shape& expected, const Shape& actual); // Asserts that the provided shapes are equal as defined in AssertEqualShapes diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 89a6530aa6d0fb5011c660984bd24ebf442e09a9..c74213f7f9198741770713aa950e78f2e5ec951d 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -814,7 +814,7 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { test_to_device_and_back(*Literal::CreateR0(true)); test_to_device_and_back(*Literal::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). @@ -835,6 +835,30 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { Literal::CreateR0(false).get()})); } +XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { + // Test copying Literals to the device as ShapedBuffers, then copying them + // back again to Literals for 64-bit values. + auto test_to_device_and_back = [this](const Literal& literal) { + TF_ASSERT_OK_AND_ASSIGN( + auto shaped_buffer, + local_client_->LiteralToShapedBuffer( + literal, local_client_->default_device_ordinal(), allocator_)); + TF_ASSERT_OK_AND_ASSIGN( + auto transferred_literal, + local_client_->ShapedBufferToLiteral(*shaped_buffer)); + EXPECT_EQ(literal, *transferred_literal); + }; + + test_to_device_and_back( + *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + test_to_device_and_back( + *Literal::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back( + *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), + Literal::CreateR0(123456789000LL).get()})); +} + // Benchmark that measures the overhead of the LocalClient API when running a // trivial computation void BM_LocalClientOverhead(int num_iters) { diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 2271f32c5946f3d3e7e6b43b089e68ab3101b61b..b48b3a2bdbb0dac3cc7db5f93aa9172dcf47bc02 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -120,10 +120,10 @@ class ReduceTest : public ClientLibraryTestBase { Computation reduce; if (and_reduce) { init_value = builder.ConstantR0(true); - reduce = CreateScalarLogicalAndComputation(&builder); + reduce = CreateScalarAndComputation(&builder); } else { init_value = builder.ConstantR0(false); - reduce = CreateScalarLogicalOrComputation(&builder); + reduce = CreateScalarOrComputation(&builder); } builder.Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); @@ -729,16 +729,14 @@ XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { std::numeric_limits::max()); } -XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalAnd) { - RunVectorizedReduceTestForType(CreateScalarLogicalAndComputation, - [](bool a, bool b) { return a && b; }, - true); +XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { + RunVectorizedReduceTestForType( + CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true); } -XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalOr) { - RunVectorizedReduceTestForType(CreateScalarLogicalOrComputation, - [](bool a, bool b) { return a || b; }, - false); +XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { + RunVectorizedReduceTestForType( + CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false); } class ReduceR3ToR2Test : public ReduceTest, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 7b7f2687286916aff9c47a7c165619bbe84368e8..6c9b62b48d8bb2ad93b2ce98839e5e52d8eaa8cc 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -76,6 +76,20 @@ class ReduceWindowTest : public ClientLibraryTestBase { ComputationBuilder builder_; }; +TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = builder_.ConstantR1({1, 1, 1, 1}); + const auto init_value = builder_.ConstantR0(0); + TF_ASSERT_OK(builder_.first_error()); + builder_.ReduceWindow(input, init_value, + CreateScalarAddComputation(F32, &builder_), + /*window_dimensions=*/{1, 2}, + /*window_strides=*/{1}, Padding::kValid); + ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) + << builder_.first_error(); + ASSERT_THAT(builder_.first_error().error_message(), + ::testing::HasSubstr("Want input dimensions size")); +} + TEST_F(ReduceWindowTest, Min3In5Stride2) { const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); ReduceWindowMin(input, {3}, {2}, Padding::kValid); diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index bb7160e3a03053a4f3d8da712c1424e50f37dfeb..72c68f24a0a954deb0564e9a0e924edfaf5b5484 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -47,7 +47,7 @@ class ReshapeTest : public ClientLibraryTestBase { }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial1x1) { +XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2({{1.0}}); builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); @@ -55,6 +55,22 @@ XLA_TEST_F(ReshapeTest, Trivial1x1) { ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); } +XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({1.0}); + builder.Collapse(/*operand=*/a, /*dimensions=*/{}); + + ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({1.0}); + builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); + + ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); +} + // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 77d1c019f3a23f79237e624dabf8972a6c1d3c72..b5e7570778ffeca66cc15d7cd2b153639637a647 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -459,39 +459,99 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { ComputeAndCompareR0(&builder, 2, {}); } -XLA_TEST_F(ScalarComputationsTest, LogicalAnd) { +XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalAnd(builder.ConstantR0(x), - builder.ConstantR0(y)); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x && y, {}); } } } -XLA_TEST_F(ScalarComputationsTest, LogicalOr) { +XLA_TEST_F(ScalarComputationsTest, AndS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, AndU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalOr(builder.ConstantR0(x), - builder.ConstantR0(y)); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x || y, {}); } } } -XLA_TEST_F(ScalarComputationsTest, LogicalNot) { +XLA_TEST_F(ScalarComputationsTest, OrS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalNot(builder.ConstantR0(x)); + builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, !x, {}); } } +XLA_TEST_F(ScalarComputationsTest, NotS32) { + for (int32 x : {-1, 0, 1}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0(x)); + + ComputeAndCompareR0(&builder, ~x, {}); + } +} + +XLA_TEST_F(ScalarComputationsTest, NotU32) { + for (uint32 x : {0, 1, 2}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0(x)); + + ComputeAndCompareR0(&builder, ~x, {}); + } +} + XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { ComputationBuilder builder(client_, TestName()); builder.Select(builder.ConstantR0(true), // The predicate. diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index bb2d90fa94abbf52c340d366ddc55f7bdefb6543..71a1b0abee51ba2819daed23208b0da8d5107207 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -169,7 +169,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) { { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); + auto result = builder.Or(prev, builder.ConstantR0(true)); body = builder.Build().ConsumeValueOrDie(); } @@ -437,7 +437,7 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); - auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); + auto new_pred = builder.Or(pred, builder.ConstantR0(true)); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index aa297ac171d76d73d4c71c7cdfd2c2e2b9fd9a3d..5ede37b8737bd4fa6235464ddeb6382af17c8a80 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -86,9 +86,9 @@ void RealMain(tensorflow::gtl::ArraySlice args) { layouts.push_back(&program_shape->parameters(i)); } StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + local_service->CompileExecutable(computation.handle(), layouts, + &program_shape->result(), + /*device_ordinal=*/0); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 2a3a8803283c62d12d8e2d213aa1730e8bd33244..78d8fb1f4330aed899ca917e66fae819a002b3a9 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -61,9 +61,9 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { layouts.push_back(&program_shape->parameters(i)); } StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + local_service->CompileExecutable(computation.handle(), layouts, + &program_shape->result(), + /*device_ordinal=*/0); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 1c7361105595ca25cd130dbb890b9e2cb694a7ac..2624ef0252fd9482a600fe3aec07f7f328a86d69 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -336,4 +336,13 @@ std::vector> CommonFactors( return bounds; } +string SanitizeFileName(string file_name) { + for (char& c : file_name) { + if (c == '/' || c == '\\' || c == '[' || c == ']') { + c = '_'; + } + } + return file_name; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 1a54c4029c8586099f26fa3cdd7fdcaf1d083dfa..f6c0bd1563f4d9090df94b6edd8226119194c76c 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -361,6 +361,9 @@ int64 Product(tensorflow::gtl::ArraySlice xs); std::vector> CommonFactors( tensorflow::gtl::ArraySlice a, tensorflow::gtl::ArraySlice b); +// Removes illegal characters from filenames. +string SanitizeFileName(string file_name); + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc index 547b924180bf59091ebd552618bf6bd5be9cd6a7..288479c893855742f7aa76fab532c5ca8f942e3c 100644 --- a/tensorflow/compiler/xla/util_test.cc +++ b/tensorflow/compiler/xla/util_test.cc @@ -122,5 +122,12 @@ TEST(UtilTest, CommonFactors) { } } +TEST(UtilTest, SanitizeFileName) { + EXPECT_EQ(SanitizeFileName(""), ""); + EXPECT_EQ(SanitizeFileName("abc"), "abc"); + EXPECT_EQ(SanitizeFileName("/\\[]"), "____"); + EXPECT_EQ(SanitizeFileName("/A\\B[C]"), "_A_B_C_"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 4840ddb8817a37c7dabcfb27e24a2a5472f4b6a2..7f4bd26d1bcc3ff9cc002adb28d2adfcf96f59ab 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -191,6 +191,11 @@ message ExecutionOptions { uint64 seed = 3; DebugOptions debug_options = 4; + + // This optional field specifies a particular set of devices to run the + // computation on. The computation will be partitioned across these devices. + // If not provided, the default device will be chosen. + repeated DeviceHandle device_handles = 5; } message SnapshotComputationRequest { @@ -312,12 +317,8 @@ message ExecuteRequest { ComputationHandle computation = 1; repeated GlobalDataHandle arguments = 2; - // This optional field specifies a particular device to run the computation. - // If not provided, the default device will be chosen. - DeviceHandle device_handle = 5; - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; + ExecutionOptions execution_options = 5; } message ExecuteParallelRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 1771a3d5deaac0388c8c3ba6a2f283231ebd3572..eae284afb768ea7bc55a21e721068aaf69a9aba3 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -392,13 +392,17 @@ message DynamicUpdateSliceRequest { } message ConvolutionDimensionNumbers { - // The number of the dimension that represents batch in the input - // (lhs) and output. - int64 batch_dimension = 1; + // The number of the dimension that represents batch in the input. + int64 input_batch_dimension = 7; - // The number of the dimension that represents features in the input - // (lhs) and output. - int64 feature_dimension = 2; + // The number of the dimension that represents features in the input. + int64 input_feature_dimension = 8; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; // The dimension numbers for the spatial dimensions that the window // moves through in the input (lhs) and output. @@ -617,8 +621,8 @@ message WhileRequest { enum UnaryOperation { UNOP_INVALID = 0; - // Elementwise, logical negation - UNOP_LOGICAL_NOT = 1; + // Elementwise, logical negation on booleans and bitwise negation on ints. + UNOP_NOT = 1; // Elementwise, computes e^x. UNOP_EXP = 2; @@ -706,9 +710,13 @@ enum BinaryOperation { // Remainder operation. BINOP_REM = 17; - // Logical operators - BINOP_LOGICAL_AND = 18; - BINOP_LOGICAL_OR = 19; + // Element-wise, logical operators on booleans and bitwise operators on ints. + BINOP_AND = 18; + BINOP_OR = 19; + + BINOP_SHIFT_LEFT = 20; + BINOP_SHIFT_RIGHT_ARITHMETIC = 21; + BINOP_SHIFT_RIGHT_LOGICAL = 22; } message BinaryOpRequest { diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 2007e09e8d715dac889ce146acabb6d582bef9d8..3d580fae142f990be249fb61119d23aa3c92210c 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -53,6 +53,7 @@ py_library( "//tensorflow/contrib/linear_optimizer:sdca_ops_py", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", + "//tensorflow/contrib/losses:metric_learning_py", "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", @@ -63,6 +64,7 @@ py_library( "//tensorflow/contrib/opt:opt_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", + "//tensorflow/contrib/quantize:quantize_graph", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", @@ -103,6 +105,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", + "//tensorflow/contrib/rnn:all_kernels", "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", "//tensorflow/contrib/tensor_forest:model_ops_kernels", "//tensorflow/contrib/tensor_forest:stats_ops_kernels", @@ -124,6 +127,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", + "//tensorflow/contrib/rnn:all_ops", "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", "//tensorflow/contrib/tensor_forest:model_ops_op_lib", "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index b50c185e3714b338ad1a9ca2af276ba51696f1ac..bf921808aa9a4694e06afcc2091b381a6fcffc49 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -56,6 +56,7 @@ from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import predictor from tensorflow.contrib import quantization +from tensorflow.contrib import quantize from tensorflow.contrib import reduce_slice_ops from tensorflow.contrib import resampler from tensorflow.contrib import rnn diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 8e7f1791b864bb30e4592a86e637d1603be6618b..a5057da9fd43a88575813613d6ac9d17fd2b2e28 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -191,7 +191,7 @@ def _ragged_split(tensor, pieces): def _ring_permutations(num_workers, num_subchunks, gpu_perm): - """"Generate an array of device index arrays, one for for each subchunk. + """"Generate an array of device index arrays, one for each subchunk. In the basic ring reduction algorithm there are size(T)/num_devices data chunks and each device process one chunk per tick, i.e. sending @@ -762,6 +762,8 @@ def _reduce_non_singleton(input_tensors, red_f, un_op): if len(input_tensors) > 1: return red_f(input_tensors) else: + if not un_op: + return input_tensors output_tensors = [] for t in input_tensors: with ops.colocate_with(t): @@ -835,7 +837,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op): + red_n_op, red_op, un_op=None): """Construct hybrid of Shuffle within workers, Ring across workers.""" def upper_builder(tensors): return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index f5710cc7c15782dfb9d0c788774f05a808590c48..80e03f20362ed41b62ce118e864ffb0acb4ab50b 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -36,6 +36,7 @@ import org.tensorflow.Operation; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.TensorFlow; +import org.tensorflow.Tensors; import org.tensorflow.types.UInt8; /** @@ -337,7 +338,7 @@ public class TensorFlowInferenceInterface { * a Java {@code String} (which is a sequence of characters). */ public void feedString(String inputName, byte[] src) { - addFeed(inputName, Tensor.create(src)); + addFeed(inputName, Tensors.create(src)); } /** @@ -346,7 +347,7 @@ public class TensorFlowInferenceInterface { * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters). */ public void feedString(String inputName, byte[][] src) { - addFeed(inputName, Tensor.create(src)); + addFeed(inputName, Tensors.create(src)); } // Methods for taking a native Tensor and filling it with src from Java native IO buffers. @@ -616,7 +617,7 @@ public class TensorFlowInferenceInterface { private List feedNames = new ArrayList(); private List> feedTensors = new ArrayList>(); private List fetchNames = new ArrayList(); - private List> fetchTensors = null; + private List> fetchTensors = new ArrayList>(); // Mutable state. private RunStats runStats; diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 726a8f692f5b6eb8392efadf136aecd890d7f5eb..1b85c260c0ce6a4a7e772b07aa5d639105232f5f 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -68,6 +68,10 @@ py_library( srcs = ["python/utils/losses.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:nn", ], ) @@ -82,7 +86,11 @@ py_test( ], deps = [ ":losses", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", "//third_party/py/numpy", ], ) @@ -94,13 +102,30 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":gen_model_ops_py", "//tensorflow/contrib/boosted_trees:batch_ops_utils_py", "//tensorflow/contrib/boosted_trees:boosted_trees_ops_py", "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", "//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/feature_column", ], ) @@ -116,10 +141,19 @@ py_test( deps = [ ":gbdt_batch", ":losses", + ":model_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", - "//third_party/py/numpy", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", ], ) @@ -138,8 +172,6 @@ py_test( ":prediction_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", @@ -181,6 +213,9 @@ py_test( deps = [ ":quantile_ops_py", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", @@ -233,7 +268,6 @@ py_test( "nomac", # b/63258195 ], deps = [ - ":boosted_trees_ops_loader", ":model_ops_py", ":training_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", @@ -268,9 +302,8 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", + "//tensorflow/python:errors", + "//tensorflow/python:platform", ], ) @@ -309,21 +342,17 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_model_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", + "//tensorflow/python:training", ], ) tf_kernel_library( name = "model_ops_kernels", - srcs = [ - "kernels/model_ops.cc", - ], + srcs = ["kernels/model_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", @@ -387,23 +416,17 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_split_handler_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", ], ) tf_kernel_library( name = "split_handler_ops_kernels", - srcs = [ - "kernels/split_handler_ops.cc", - ], + srcs = ["kernels/split_handler_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -435,25 +458,21 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_training_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", ], ) tf_kernel_library( name = "training_ops_kernels", - srcs = [ - "kernels/training_ops.cc", - ], + srcs = ["kernels/training_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", + "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", + "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", - "//tensorflow/core:framework", + "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -490,9 +509,7 @@ tf_custom_op_py_library( tf_kernel_library( name = "prediction_ops_kernels", - srcs = [ - "kernels/prediction_ops.cc", - ], + srcs = ["kernels/prediction_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:example_partitioner", "//tensorflow/contrib/boosted_trees/lib:models", @@ -500,7 +517,6 @@ tf_kernel_library( "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", ], @@ -532,26 +548,22 @@ tf_custom_op_py_library( ":batch_ops_utils_py", ":boosted_trees_ops_loader", ":gen_quantile_ops_py_wrap", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", ], ) tf_kernel_library( name = "quantile_ops_kernels", - srcs = [ - "kernels/quantile_ops.cc", - ], + srcs = ["kernels/quantile_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -581,8 +593,6 @@ tf_custom_op_py_library( ":batch_ops_utils_py", ":boosted_trees_ops_loader", ":gen_stats_accumulator_ops_py_wrap", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", "//tensorflow/python:training", @@ -591,13 +601,10 @@ tf_custom_op_py_library( tf_kernel_library( name = "stats_accumulator_ops_kernels", - srcs = [ - "kernels/stats_accumulator_ops.cc", - ], + srcs = ["kernels/stats_accumulator_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/resources:stamped_resource", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", ], alwayslink = 1, @@ -609,7 +616,12 @@ py_library( name = "boosted_trees_pip", deps = [ ":init_py", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy", "//tensorflow/contrib/boosted_trees/estimator_batch:init_py", + "//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks", + "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", + "//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", diff --git a/tensorflow/contrib/boosted_trees/README.md b/tensorflow/contrib/boosted_trees/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ce700f1a19fde3f5b07748fd6768e9e8e336c8a --- /dev/null +++ b/tensorflow/contrib/boosted_trees/README.md @@ -0,0 +1,11 @@ +# TF Boosted Trees (TFBT) + +TF Boosted trees is an implementation of a gradient boosting algorithm with +trees used as week learners. + +## Examples +Folder "examples" demonstrates how TFBT estimators can be used for various +problems. Namely, it contains: +* binary_mnist.py - an example on how to use TFBT for binary classification. +* mnist.py - a multiclass example. +* boston.py - a regression example. \ No newline at end of file diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index f9e186788f6832b292a690d8d7b04e2f4edd584e..d0ee1fd60d0b62395f6638ab3d67e6fe95ae8331 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -27,13 +27,6 @@ py_library( "__init__.py", ], srcs_version = "PY2AND3", - deps = [ - "custom_export_strategy", - ":custom_loss_head", - ":estimator", - ":model", - ":trainer_hooks", - ], ) py_library( @@ -41,7 +34,12 @@ py_library( srcs = ["model.py"], srcs_version = "PY2AND3", deps = [ + ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees:model_ops_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", ], ) @@ -51,6 +49,10 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", ], ) @@ -61,6 +63,15 @@ py_test( srcs_version = "PY2AND3", deps = [ ":trainer_hooks", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) @@ -69,6 +80,10 @@ py_library( srcs = ["custom_loss_head.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", ], ) @@ -82,6 +97,11 @@ py_library( "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", "//tensorflow/contrib/learn", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", ], ) @@ -92,8 +112,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":custom_export_strategy", - "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py", - "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", ], ) @@ -103,6 +124,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":model", - ":trainer_hooks", + "//tensorflow/contrib/learn", ], ) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 7773125c16772fe37369b11532c7f42df3ce166f..a800c3ddc7954133652d53b8fa381d4f1b3b5d40 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -96,7 +96,8 @@ def make_custom_export_strategy(name, def convert_to_universal_format(dtec, sorted_feature_names, num_dense, num_sparse_float, - num_sparse_int): + num_sparse_int, + feature_name_to_proto=None): """Convert GTFlow trees to universal format.""" del num_sparse_int # unused. model_and_features = generic_tree_model_pb2.ModelAndFeatures() @@ -104,7 +105,11 @@ def convert_to_universal_format(dtec, sorted_feature_names, # feature is processed before it's fed to the model (e.g. bucketing # information). As of now, this serves as a list of features the model uses. for feature_name in sorted_feature_names: - model_and_features.features[feature_name].SetInParent() + if not feature_name_to_proto: + model_and_features.features[feature_name].SetInParent() + else: + model_and_features.features[feature_name].CopyFrom( + feature_name_to_proto[feature_name]) model = model_and_features.model model.ensemble.summation_combination_technique.SetInParent() for tree_idx in range(len(dtec.trees)): diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 8cda5c8f2b14f2ec3cfe3702e38b81803dd075f7..c6455a7ea3d18eb358edee034cee58b2bed21024 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -93,7 +93,7 @@ def model_builder(features, labels, mode, params, config): learner_config=learner_config, feature_columns=feature_columns, logits_dimension=head.logits_dimension, - features=features) + features=training_features) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] diff --git a/tensorflow/contrib/boosted_trees/examples/binary_mnist.py b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..c003b1de66fff63b3faf856f17cdce8c877922ba --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py @@ -0,0 +1,169 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates multiclass MNIST TF Boosted trees example. + + This example demonstrates how to run experiments with TF Boosted Trees on + a binary dataset. We use digits 4 and 9 from the original MNIST dataset. + + Example Usage: + python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \ + --output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \ + --batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \ + --num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1 \ + + When training is done, accuracy on eval data is reported. Point tensorboard + to the directory for the run to see how the training progresses: + + tensorboard --logdir=/tmp/binary_mnist + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.learn import learn_runner + + +def get_input_fn(data, + batch_size, + capacity=10000, + min_after_dequeue=3000): + """Input function over MNIST data.""" + # Keep only 4 and 9 digits. + ids = np.where((data.labels == 4) | (data.labels == 9)) + images = data.images[ids] + labels = data.labels[ids] + # Make digit 4 label 1, 9 is 0. + labels = labels == 4 + + def _input_fn(): + """Prepare features and labels.""" + images_batch, labels_batch = tf.train.shuffle_batch( + tensors=[images, + labels.astype(np.int32)], + batch_size=batch_size, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + enqueue_many=True, + num_threads=4) + features_map = {"images": images_batch} + return features_map, labels_batch + + return _input_fn + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer + learner_config.constraints.max_tree_depth = FLAGS.depth + + growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + learner_config.growing_mode = growing_mode + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + # Create a TF Boosted trees estimator that can take in custom loss. + estimator = GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + examples_per_layer=FLAGS.examples_per_layer, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + data = tf.contrib.learn.datasets.mnist.load_mnist() + train_input_fn = get_input_fn(data.train, FLAGS.batch_size) + eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size) + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--eval_batch_size", + type=int, + default=1000, + help="Size of the batch for eval.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--examples_per_layer", + type=int, + default=1000, + help="Number of examples to accumulate stats for per layer.") + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0a3c4912b82aba88e2f8f1b97a227c894ee2ae --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates a regression on Boston housing data. + + This example demonstrates how to run experiments with TF Boosted Trees on + a regression dataset. We split all the data into 20% test and 80% train, + and are using l2 loss and l2 regularization. + + Example Usage: + + python tensorflow/contrib/boosted_trees/examples/boston.py \ + --batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \ + --num_eval_steps=1 --num_trees=500 --l2=4 \ + --vmodule=training_ops=1 + + When training is done, mean squared error on eval data is reported. + Point tensorboard to the directory for the run to see how the training + progresses: + + tensorboard --logdir=/tmp/boston + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn import learn_runner + +_BOSTON_NUM_FEATURES = 13 + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir, feature_cols): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.regularization.l1 = 0.0 + # Set the regularization per instance in such a way that + # regularization for the full training data is equal to l2 flag. + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.batch_size + learner_config.constraints.max_tree_depth = FLAGS.depth + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + # Create a TF Boosted trees regression estimator. + estimator = GradientBoostedDecisionTreeRegressor( + learner_config=learner_config, + # For the WHOLE_TREE strategy, set the examples_per_layer to be equal to + # batch size. + examples_per_layer=FLAGS.batch_size, + feature_columns=feature_cols, + label_dimension=1, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + (x_train, y_train), (x_test, + y_test) = tf.keras.datasets.boston_housing.load_data() + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_train}, + y=y_train, + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + + eval_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) + + feature_columns = [ + feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) + ] + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir, feature_columns), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/examples/mnist.py b/tensorflow/contrib/boosted_trees/examples/mnist.py index 7e34d2f2d36e1022845fa63ba44b2df7d8a2cf55..a3b1cb5154644e1a97633d429aae8ae18ecdaa2b 100644 --- a/tensorflow/contrib/boosted_trees/examples/mnist.py +++ b/tensorflow/contrib/boosted_trees/examples/mnist.py @@ -129,8 +129,8 @@ def _get_tfbt(output_dir): def _make_experiment_fn(output_dir): """Creates experiment for gradient boosted decision trees.""" data = tf.contrib.learn.datasets.mnist.load_mnist() - train_input_fn = get_input_fn(data.train, batch_size=256) - eval_input_fn = get_input_fn(data.validation, batch_size=5000) + train_input_fn = get_input_fn(data.train, FLAGS.batch_size) + eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size) return tf.contrib.learn.Experiment( estimator=_get_tfbt(output_dir), diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index f4ad99f779e0d7fcf207934d77776548214371c1..4b5d5ba0de6c3995ee2da7a44ab0ba099cbf1b35 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -15,7 +15,6 @@ #include #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" -#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -46,9 +45,8 @@ class CreateTreeEnsembleVariableOp : public OpKernel { OP_REQUIRES_OK(context, context->input("tree_ensemble_config", &tree_ensemble_config_t)); auto* result = new boosted_trees::models::DecisionTreeEnsembleResource(); - result->set_stamp(stamp_token); - if (!ParseProtoUnlimited(result->mutable_decision_tree_ensemble(), - tree_ensemble_config_t->scalar()())) { + if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), + stamp_token)) { result->Unref(); OP_REQUIRES(context, false, errors::InvalidArgument( "Unable to parse tree ensemble config.")); @@ -70,17 +68,15 @@ class TreeEnsembleStampTokenOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + &ensemble_resource)); + tf_shared_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_stamp_token_t)); - output_stamp_token_t->scalar()() = - decision_tree_ensemble_resource->stamp(); + output_stamp_token_t->scalar()() = ensemble_resource->stamp(); } }; @@ -91,23 +87,20 @@ class TreeEnsembleSerializeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + &ensemble_resource)); + tf_shared_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_stamp_token_t)); - output_stamp_token_t->scalar()() = - decision_tree_ensemble_resource->stamp(); + output_stamp_token_t->scalar()() = ensemble_resource->stamp(); Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(1, TensorShape(), &output_config_t)); output_config_t->scalar()() = - decision_tree_ensemble_resource->decision_tree_ensemble() - .SerializeAsString(); + ensemble_resource->SerializeAsString(); } }; @@ -118,12 +111,11 @@ class TreeEnsembleDeserializeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + &ensemble_resource)); + mutex_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); // Get the stamp token. const Tensor* stamp_token_t; @@ -135,13 +127,11 @@ class TreeEnsembleDeserializeOp : public OpKernel { OP_REQUIRES_OK(context, context->input("tree_ensemble_config", &tree_ensemble_config_t)); // Deallocate all the previous objects on the resource. - decision_tree_ensemble_resource->Reset(); - decision_tree_ensemble_resource->set_stamp(stamp_token); - boosted_trees::trees::DecisionTreeEnsembleConfig* config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); + ensemble_resource->Reset(); OP_REQUIRES( context, - ParseProtoUnlimited(config, tree_ensemble_config_t->scalar()()), + ensemble_resource->InitFromSerialized( + tree_ensemble_config_t->scalar()(), stamp_token), errors::InvalidArgument("Unable to parse tree ensemble config.")); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index 54b0c7842a4e6a7f433bf6e93762559bc4d9faf2..766982b4f2023310e6046619939f83bef63b0302 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -59,8 +59,27 @@ const char* kApplyDropoutAttributeName = "apply_dropout"; const char* kApplyAveragingAttributeName = "apply_averaging"; const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights"; const char* kPredictionsTensorName = "predictions"; -const char* kNoDropoutPredictionsTensorName = "no_dropout_predictions"; + +void CalculateTreesToInclude( + const boosted_trees::trees::DecisionTreeEnsembleConfig& config, + const std::vector& trees_to_drop, const int32 num_trees, + const bool only_finalized, std::vector* trees_to_include) { + trees_to_include->reserve(num_trees - trees_to_drop.size()); + + int32 index = 0; + // This assumes that trees_to_drop is a sorted list of tree ids. + for (int32 tree = 0; tree < num_trees; ++tree) { + if ((!trees_to_drop.empty() && index < trees_to_drop.size() && + trees_to_drop[index] == tree) || + (only_finalized && config.tree_metadata_size() > 0 && + !config.tree_metadata(tree).is_finalized())) { + ++index; + continue; + } + trees_to_include->push_back(tree); + } } +} // namespace class GradientTreesPredictionOp : public OpKernel { public: @@ -136,24 +155,23 @@ class GradientTreesPredictionOp : public OpKernel { } void Compute(OpKernelContext* const context) override { - DecisionTreeEnsembleResource* decision_tree_ensemble_resource; + DecisionTreeEnsembleResource* ensemble_resource; // Gets the resource. Grabs the mutex but releases it. OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); + &ensemble_resource)); // Release the reference to the resource once we're done using it. - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - DoCompute(context, decision_tree_ensemble_resource); + tf_shared_lock l(*ensemble_resource->get_mutex()); + DoCompute(context, ensemble_resource); } else { - DoCompute(context, decision_tree_ensemble_resource); + DoCompute(context, ensemble_resource); } } private: - void DoCompute( - OpKernelContext* context, - DecisionTreeEnsembleResource* decision_tree_ensemble_resource) { + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource) { // Read dense float features list; OpInputList dense_float_features_list; OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( @@ -205,41 +223,35 @@ class GradientTreesPredictionOp : public OpKernel { // Do dropout if needed. if (apply_dropout_ && has_dropout_) { - // Read in seed + // Read in seed and cast to uint64. const Tensor* seed_t; OP_REQUIRES_OK(context, context->input(kSeedTensorName, &seed_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_t->shape()), errors::InvalidArgument("Seed must be a scalar.")); - - // Cast seed to uint64. const uint64 seed = seed_t->scalar()(); - std::vector weights; - for (const float weight : - decision_tree_ensemble_resource->decision_tree_ensemble() - .tree_weights()) { - weights.push_back(weight); - } - std::unordered_set trees_not_to_drop; if (center_bias_) { trees_not_to_drop.insert(0); } - if (decision_tree_ensemble_resource->decision_tree_ensemble() - .has_growing_metadata()) { + if (ensemble_resource->decision_tree_ensemble().has_growing_metadata()) { // We are in batch mode, the last tree is the tree that is being built, // we can't drop it during dropout. - const int32 current_tree = - decision_tree_ensemble_resource->decision_tree_ensemble() - .trees_size() - - 1; - trees_not_to_drop.insert(current_tree); + trees_not_to_drop.insert(ensemble_resource->num_trees() - 1); } + const std::vector weights = ensemble_resource->GetTreeWeights(); OP_REQUIRES_OK(context, DropoutUtils::DropOutTrees( seed, dropout_config_, trees_not_to_drop, weights, &dropped_trees, &original_weights)); } + // Prepare the list of trees to include in the prediction. + std::vector trees_to_include; + CalculateTreesToInclude( + ensemble_resource->decision_tree_ensemble(), dropped_trees, + ensemble_resource->decision_tree_ensemble().trees_size(), + only_finalized_trees_, &trees_to_include); + // Allocate output predictions matrix. Tensor* output_predictions_t = nullptr; OP_REQUIRES_OK( @@ -248,22 +260,13 @@ class GradientTreesPredictionOp : public OpKernel { &output_predictions_t)); auto output_predictions = output_predictions_t->matrix(); - Tensor* output_no_dropout_predictions_t = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output(kNoDropoutPredictionsTensorName, - {batch_size, prediction_vector_size_}, - &output_no_dropout_predictions_t)); - auto output_no_dropout_predictions = - output_no_dropout_predictions_t->matrix(); - // Run predictor. thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; if (apply_averaging_) { DecisionTreeEnsembleConfig adjusted = - decision_tree_ensemble_resource->decision_tree_ensemble(); - + ensemble_resource->decision_tree_ensemble(); const int start_averaging = std::max( 0.0, averaging_config_.config_case() == @@ -271,21 +274,18 @@ class GradientTreesPredictionOp : public OpKernel { ? adjusted.trees_size() - averaging_config_.average_last_n_trees() : adjusted.trees_size() * (1.0 - averaging_config_.average_last_percent_trees())); - const int num_ensembles = adjusted.trees_size() - start_averaging; for (int i = start_averaging; i < adjusted.trees_size(); ++i) { float weight = adjusted.tree_weights(i); adjusted.mutable_tree_weights()->Set( i, weight * (num_ensembles - i + start_averaging) / num_ensembles); } - MultipleAdditiveTrees::Predict( - adjusted, only_finalized_trees_, dropped_trees, batch_features, - worker_threads, output_predictions, output_no_dropout_predictions); + MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features, + worker_threads, output_predictions); } else { MultipleAdditiveTrees::Predict( - decision_tree_ensemble_resource->decision_tree_ensemble(), - only_finalized_trees_, dropped_trees, batch_features, worker_threads, - output_predictions, output_no_dropout_predictions); + ensemble_resource->decision_tree_ensemble(), trees_to_include, + batch_features, worker_threads, output_predictions); } // Output dropped trees and original weights. @@ -327,37 +327,32 @@ class GradientTreesPartitionExamplesOp : public OpKernel { } void Compute(OpKernelContext* const context) override { - DecisionTreeEnsembleResource* decision_tree_ensemble_resource; + DecisionTreeEnsembleResource* ensemble_resource; // Gets the resource. Grabs the mutex but releases it. OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); + &ensemble_resource)); // Release the reference to the resource once we're done using it. - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - DoCompute(context, decision_tree_ensemble_resource); + tf_shared_lock l(*ensemble_resource->get_mutex()); + DoCompute(context, ensemble_resource); } else { - DoCompute(context, decision_tree_ensemble_resource); + DoCompute(context, ensemble_resource); } } private: - void DoCompute( - OpKernelContext* context, - DecisionTreeEnsembleResource* decision_tree_ensemble_resource) { + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource) { // The last non-finalized tree in the ensemble is by convention the // one to partition on. If no such tree exists, a nodeless tree is // created. - const auto& tree_ensemble = - decision_tree_ensemble_resource->decision_tree_ensemble(); - boosted_trees::trees::DecisionTreeConfig empy_tree_config; - const boosted_trees::trees::DecisionTreeConfig* tree_config = - &empy_tree_config; - auto num_trees = tree_ensemble.trees_size(); - if (num_trees > 0 && - !tree_ensemble.tree_metadata(num_trees - 1).is_finalized()) { - tree_config = &tree_ensemble.trees(num_trees - 1); - } + boosted_trees::trees::DecisionTreeConfig empty_tree_config; + const boosted_trees::trees::DecisionTreeConfig& tree_config = + (ensemble_resource->num_trees() <= 0 || + ensemble_resource->LastTreeMetadata()->is_finalized()) + ? empty_tree_config + : *ensemble_resource->LastTree(); // Read dense float features list; OpInputList dense_float_features_list; @@ -412,7 +407,7 @@ class GradientTreesPartitionExamplesOp : public OpKernel { thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; learner::ExamplePartitioner::PartitionExamples( - *tree_config, batch_features, worker_threads->NumThreads(), + tree_config, batch_features, worker_threads->NumThreads(), worker_threads, partition_ids_t->vec().data()); } diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 2c14b042925dd393e51ed1bf424c320909784fef..4c56718f1bbc0b42c1f5454ddfafe6ccd8c35c2c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -24,14 +24,13 @@ using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; namespace boosted_trees { -using boosted_trees::trees::DecisionTreeEnsembleConfig; +namespace { + +using boosted_trees::learner::LearningRateConfig; +using boosted_trees::trees::Leaf; using boosted_trees::trees::TreeNode; using boosted_trees::trees::TreeNodeMetadata; using boosted_trees::utils::DropoutUtils; -using boosted_trees::learner::LearningRateConfig; -using boosted_trees::trees::Leaf; - -namespace { // SplitCandidate holds the split candidate node along with the stats. struct SplitCandidate { @@ -187,12 +186,11 @@ class CenterTreeEnsembleBiasOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + &ensemble_resource)); + core::ScopedUnref unref_me(ensemble_resource); + mutex_lock l(*ensemble_resource->get_mutex()); // Get the stamp token. const Tensor* stamp_token_t; @@ -201,7 +199,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel { // Only the Chief should run this Op and it is guaranteed to be in // a consistent state so the stamps must always match. - CHECK(decision_tree_ensemble_resource->is_stamp_valid(stamp_token)); + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); // Get the next stamp token. const Tensor* next_stamp_token_t; @@ -221,11 +219,10 @@ class CenterTreeEnsembleBiasOp : public OpKernel { auto delta_updates = delta_updates_t->vec(); // Update the ensemble stamp. - decision_tree_ensemble_resource->set_stamp(next_stamp_token); + ensemble_resource->set_stamp(next_stamp_token); // Get the bias. - boosted_trees::trees::Leaf* bias = - RetrieveBias(decision_tree_ensemble_resource); + boosted_trees::trees::Leaf* const bias = RetrieveBias(ensemble_resource); CHECK(bias->has_vector()); OP_REQUIRES( context, @@ -259,35 +256,26 @@ class CenterTreeEnsembleBiasOp : public OpKernel { private: // Helper method to retrieve the bias from the tree ensemble. boosted_trees::trees::Leaf* RetrieveBias( - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource) { - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); - const auto num_trees = ensemble_config->trees_size(); - CHECK(num_trees == ensemble_config->tree_metadata_size() && - num_trees == ensemble_config->tree_weights_size()); + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) { + const int32 num_trees = ensemble_resource->num_trees(); if (num_trees <= 0) { - ensemble_config->mutable_growing_metadata()->set_num_trees_attempted(1); - ensemble_config->mutable_growing_metadata()->set_num_layers_attempted(1); // Add a new bias leaf. - boosted_trees::trees::DecisionTreeConfig* tree_config = - ensemble_config->add_trees(); - auto* leaf = tree_config->add_nodes()->mutable_leaf(); + ensemble_resource->IncrementAttempts(); + boosted_trees::trees::DecisionTreeConfig* const tree_config = + ensemble_resource->AddNewTree(1.0); + auto* const leaf = tree_config->add_nodes()->mutable_leaf(); for (size_t idx = 0; idx + 1 < learner_config_.num_classes(); ++idx) { - leaf->mutable_vector()->add_value(0); + leaf->mutable_vector()->add_value(0.0); } - ensemble_config->add_tree_weights(1.0); - boosted_trees::trees::DecisionTreeMetadata* tree_metadata = - ensemble_config->add_tree_metadata(); - tree_metadata->set_num_layers_grown(1); - tree_metadata->set_is_finalized(true); + ensemble_resource->LastTreeMetadata()->set_is_finalized(true); return leaf; } else if (num_trees == 1) { - // Update the existing bias. - CHECK_EQ(ensemble_config->trees(0).nodes_size(), 1); - auto* node = ensemble_config->mutable_trees(0)->mutable_nodes(0); - CHECK(node->node_case() == TreeNode::kLeaf); - return node->mutable_leaf(); + // Confirms that the only tree is a bias and returns its leaf. + boosted_trees::trees::DecisionTreeConfig* const tree_config = + ensemble_resource->LastTree(); + CHECK_EQ(tree_config->nodes_size(), 1); + CHECK_EQ(tree_config->nodes(0).node_case(), TreeNode::kLeaf); + return tree_config->mutable_nodes(0)->mutable_leaf(); } else { LOG(FATAL) << "Unable to center bias on an already grown ensemble"; } @@ -331,12 +319,11 @@ class GrowTreeEnsembleOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + &ensemble_resource)); + core::ScopedUnref unref_me(ensemble_resource); + mutex_lock l(*ensemble_resource->get_mutex()); // Get the stamp token. const Tensor* stamp_token_t; @@ -345,7 +332,7 @@ class GrowTreeEnsembleOp : public OpKernel { // Only the Chief should run this Op and it is guaranteed to be in // a consistent state so the stamps must always match. - CHECK(decision_tree_ensemble_resource->is_stamp_valid(stamp_token)); + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); // Get the next stamp token. const Tensor* next_stamp_token_t; @@ -356,7 +343,7 @@ class GrowTreeEnsembleOp : public OpKernel { // Update the ensemble stamp regardless of whether a layer // or tree is actually grown. - decision_tree_ensemble_resource->set_stamp(next_stamp_token); + ensemble_resource->set_stamp(next_stamp_token); // Read the learning_rate. const Tensor* learning_rate_t; @@ -378,16 +365,8 @@ class GrowTreeEnsembleOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list("gains", &gains_list)); OP_REQUIRES_OK(context, context->input_list("splits", &splits_list)); - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); - ensemble_config->mutable_growing_metadata()->set_num_layers_attempted( - ensemble_config->growing_metadata().num_layers_attempted() + 1); - const int num_trees = ensemble_config->trees_size(); - if (num_trees <= 0 || - ensemble_config->tree_metadata(num_trees - 1).is_finalized()) { - ensemble_config->mutable_growing_metadata()->set_num_trees_attempted( - ensemble_config->growing_metadata().num_trees_attempted() + 1); - } + // Increment attempt stats. + ensemble_resource->IncrementAttempts(); // Find best splits for each active partition. std::map best_splits; @@ -400,14 +379,12 @@ class GrowTreeEnsembleOp : public OpKernel { return; } - // Update and retrieve the growable tree with its metadata. - boosted_trees::trees::DecisionTreeConfig* tree_config; - boosted_trees::trees::DecisionTreeMetadata* tree_metadata; - - // Updates the tree. If the tree is fully built and dropout was applied, it - // also adjusts the weights of dropped and the last tree. - std::tie(tree_config, tree_metadata) = UpdateAndRetrieveGrowableTree( - decision_tree_ensemble_resource, learning_rate, dropout_seed); + // Update and retrieve the growable tree. + // If the tree is fully built and dropout was applied, it also adjusts the + // weights of dropped and the last tree. + boosted_trees::trees::DecisionTreeConfig* const tree_config = + UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, + dropout_seed); // Split tree nodes. for (auto& split_entry : best_splits) { @@ -417,16 +394,14 @@ class GrowTreeEnsembleOp : public OpKernel { // Post-prune finalized tree if needed. if (learner_config_.pruning_mode() == boosted_trees::learner::LearnerConfig::POST_PRUNE && - tree_metadata->is_finalized()) { + ensemble_resource->LastTreeMetadata()->is_finalized()) { VLOG(2) << "Post-pruning finalized tree."; PruneTree(tree_config); // If after post-pruning the whole tree has no gain, remove the tree // altogether from the ensemble. if (tree_config->nodes_size() <= 0) { - ensemble_config->mutable_trees()->RemoveLast(); - ensemble_config->mutable_tree_weights()->RemoveLast(); - ensemble_config->mutable_tree_metadata()->RemoveLast(); + ensemble_resource->RemoveLastTree(); } } } @@ -471,111 +446,88 @@ class GrowTreeEnsembleOp : public OpKernel { } void UpdateTreeWeightsIfDropout( - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config, - boosted_trees::trees::DecisionTreeMetadata* tree_metadata, + boosted_trees::models::DecisionTreeEnsembleResource* const + ensemble_resource, const uint64 dropout_seed) { // It is possible that the tree was built with dropout. If it is the case, - // we need to adjust the tree weight. - if (dropout_was_applied_ && tree_metadata->is_finalized()) { - const int32 num_trees = ensemble_config->trees_size(); - - std::vector dropped_trees; - // Since only chief builds the trees, we are sure that the other tree - // weights didn't change. - std::vector weights; - weights.reserve(num_trees); - std::vector num_updates; - num_updates.reserve(num_trees); - for (int i = 0; i < num_trees; ++i) { - weights.push_back(ensemble_config->tree_weights(i)); - num_updates.push_back( - ensemble_config->tree_metadata(i).num_tree_weight_updates()); - } + // we need to adjust the tree weight, or bail out. + if (!dropout_was_applied_ || + !ensemble_resource->LastTreeMetadata()->is_finalized()) { + return; + } + const int32 num_trees = ensemble_resource->num_trees(); - std::vector dropped_trees_weights; - // Based on seed, figure out what trees were dropped before. - std::unordered_set trees_not_to_drop; - if (center_bias_) { - trees_not_to_drop.insert(0); - } - // Last tree is the current tree that is built. - const int32 current_tree = num_trees - 1; - trees_not_to_drop.insert(current_tree); - - const auto dropout_status = DropoutUtils::DropOutTrees( - dropout_seed, dropout_config_, trees_not_to_drop, weights, - &dropped_trees, &dropped_trees_weights); - CHECK(dropout_status.ok()) - << "Can't figure out what trees were dropped out before, error is " - << dropout_status.error_message(); - - // Now we have dropped trees, update their weights and the current tree - // weight. - if (!dropped_trees.empty()) { - DropoutUtils::GetTreesWeightsForAddingTrees( - dropped_trees, dropped_trees_weights, current_tree, - 1 /* only 1 tree was added */, &weights, &num_updates); - - // Update the weights and num of updates for trees. - for (int i = 0; i < num_trees; ++i) { - ensemble_config->set_tree_weights(i, weights[i]); - ensemble_config->mutable_tree_metadata(i) - ->set_num_tree_weight_updates(num_updates[i]); - } + // Based on seed, figure out what trees were dropped before. + std::unordered_set trees_not_to_drop; + if (center_bias_) { + trees_not_to_drop.insert(0); + } + // Last tree is the current tree that is built. + const int32 current_tree = num_trees - 1; + trees_not_to_drop.insert(current_tree); + + // Since only chief builds the trees, we are sure that the other tree + // weights didn't change. + std::vector weights = ensemble_resource->GetTreeWeights(); + std::vector dropped_trees; + std::vector dropped_trees_weights; + const auto dropout_status = DropoutUtils::DropOutTrees( + dropout_seed, dropout_config_, trees_not_to_drop, weights, + &dropped_trees, &dropped_trees_weights); + CHECK(dropout_status.ok()) + << "Can't figure out what trees were dropped out before, error is " + << dropout_status.error_message(); + + // Now we have dropped trees, update their weights and the current tree + // weight. + if (!dropped_trees.empty()) { + std::vector increment_num_updates(num_trees, 0); + DropoutUtils::GetTreesWeightsForAddingTrees( + dropped_trees, dropped_trees_weights, current_tree, + 1 /* only 1 tree was added */, &weights, &increment_num_updates); + + // Update the weights and num of updates for trees. + for (int i = 0; i < num_trees; ++i) { + ensemble_resource->SetTreeWeight(i, weights[i], + increment_num_updates[i]); } } } - // Helper method to update and retrieve the growable tree which is by - // definition the last tree in the ensemble. - std::pair - UpdateAndRetrieveGrowableTree( - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource, - float learning_rate, const uint64 dropout_seed) { - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); - auto num_trees = ensemble_config->trees_size(); - CHECK(num_trees == ensemble_config->tree_metadata_size() && - num_trees == ensemble_config->tree_weights_size()); + // Helper method to update the growable tree which is by definition the last + // tree in the ensemble. + boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree( + boosted_trees::models::DecisionTreeEnsembleResource* const + ensemble_resource, + const float learning_rate, const uint64 dropout_seed) { + const auto num_trees = ensemble_resource->num_trees(); if (num_trees <= 0 || - ensemble_config->tree_metadata(num_trees - 1).is_finalized()) { + ensemble_resource->LastTreeMetadata()->is_finalized()) { // Create a new tree with a no-op leaf. - boosted_trees::trees::DecisionTreeConfig* tree_config = - ensemble_config->add_trees(); - ++num_trees; - VLOG(1) << "Adding layer 0 to tree " << num_trees - 1 - << " of ensemble of " << num_trees << " trees."; + boosted_trees::trees::DecisionTreeConfig* const tree_config = + ensemble_resource->AddNewTree(learning_rate); + VLOG(1) << "Adding layer #0 to tree #" << num_trees << " of ensemble of " + << num_trees + 1 << " trees."; tree_config->add_nodes()->mutable_leaf(); - ensemble_config->add_tree_weights(learning_rate); - boosted_trees::trees::DecisionTreeMetadata* tree_metadata = - ensemble_config->add_tree_metadata(); - tree_metadata->set_num_layers_grown(1); + boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = + ensemble_resource->LastTreeMetadata(); tree_metadata->set_is_finalized( learner_config_.constraints().max_tree_depth() <= 1); tree_metadata->set_num_tree_weight_updates(1); - - UpdateTreeWeightsIfDropout(ensemble_config, tree_metadata, dropout_seed); - return std::make_pair(tree_config, tree_metadata); } else { // The growable tree is by definition the last tree in the ensemble. - boosted_trees::trees::DecisionTreeMetadata* tree_metadata = - ensemble_config->mutable_tree_metadata(num_trees - 1); - auto num_layers_grown = tree_metadata->num_layers_grown(); - VLOG(1) << "Adding layer " << num_layers_grown << " to tree " + boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = + ensemble_resource->LastTreeMetadata(); + const auto new_num_layers = tree_metadata->num_layers_grown() + 1; + VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #" << num_trees - 1 << " of ensemble of " << num_trees << " trees."; // Update growable tree metadata. - ++num_layers_grown; - tree_metadata->set_num_layers_grown(num_layers_grown); + tree_metadata->set_num_layers_grown(new_num_layers); tree_metadata->set_is_finalized( - num_layers_grown >= learner_config_.constraints().max_tree_depth()); - auto* tree_config = ensemble_config->mutable_trees(num_trees - 1); - - UpdateTreeWeightsIfDropout(ensemble_config, tree_metadata, dropout_seed); - - return std::make_pair(tree_config, tree_metadata); + new_num_layers >= learner_config_.constraints().max_tree_depth()); } + UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed); + return ensemble_resource->LastTree(); } // Helper method to merge leaf weights as the tree is being grown. @@ -763,12 +715,11 @@ class TreeEnsembleStatsOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); + &ensemble_resource)); + core::ScopedUnref unref_me(ensemble_resource); + tf_shared_lock l(*ensemble_resource->get_mutex()); // Get the stamp token. const Tensor* stamp_token_t; @@ -777,9 +728,9 @@ class TreeEnsembleStatsOp : public OpKernel { // Only the Chief should run this Op and it is guaranteed to be in // a consistent state so the stamps must always match. - CHECK(decision_tree_ensemble_resource->is_stamp_valid(stamp_token)); + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); const boosted_trees::trees::DecisionTreeEnsembleConfig& ensemble_config = - decision_tree_ensemble_resource->decision_tree_ensemble(); + ensemble_resource->decision_tree_ensemble(); // Set tree stats. Tensor* num_trees_t = nullptr; @@ -794,13 +745,13 @@ class TreeEnsembleStatsOp : public OpKernel { context->allocate_output("attempted_trees", TensorShape({}), &attempted_tree_t)); - int num_trees = ensemble_config.trees_size(); + const int num_trees = ensemble_resource->num_trees(); active_tree_t->scalar()() = num_trees; - if (num_trees > 0 && - !ensemble_config.tree_metadata(num_trees - 1).is_finalized()) { - --num_trees; - } - num_trees_t->scalar()() = num_trees; + num_trees_t->scalar()() = + (num_trees <= 0 || + ensemble_resource->LastTreeMetadata()->is_finalized()) + ? num_trees + : num_trees - 1; attempted_tree_t->scalar()() = ensemble_config.growing_metadata().num_trees_attempted(); diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index d4d405c3a9a894e333fdf2278625d510cdeef1fe..70aa0284a6fcd822b854888259b41cdf60d22af5 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -132,7 +132,6 @@ tf_cc_test( ":random_tree_gen", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -149,7 +148,6 @@ cc_library( deps = [ ":utils", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:testlib", ], @@ -197,7 +195,6 @@ tf_cc_test( srcs = ["quantiles/weighted_quantiles_buffer_test.cc"], deps = [ ":weighted_quantiles", - "//tensorflow/core", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -210,7 +207,6 @@ tf_cc_test( srcs = ["quantiles/weighted_quantiles_summary_test.cc"], deps = [ ":weighted_quantiles", - "//tensorflow/core", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -262,6 +258,8 @@ py_library( srcs = ["learner/batch/base_split_handler.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/boosted_trees:batch_ops_utils_py", + "//tensorflow/python:control_flow_ops", ], ) @@ -271,9 +269,13 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base_split_handler", - "//tensorflow/contrib/boosted_trees:quantile_ops_py", "//tensorflow/contrib/boosted_trees:split_handler_ops_py", "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", ], ) @@ -285,7 +287,15 @@ py_test( ":categorical_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", ], ) @@ -298,7 +308,14 @@ py_library( "//tensorflow/contrib/boosted_trees:quantile_ops_py", "//tensorflow/contrib/boosted_trees:split_handler_ops_py", "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py", - "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", ], ) @@ -310,7 +327,15 @@ py_test( ":ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", ], ) diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc index 16bffd9beccfad352820c805e08bec71f3705f42..43b00d4c6dc2e0066810012292874314215c41be 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -21,82 +21,14 @@ namespace tensorflow { namespace boosted_trees { namespace models { -namespace { -void CalculateTreesToKeep( - const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const std::vector& trees_to_drop, const int32 num_trees, - const bool only_finalized, std::vector* trees_to_keep) { - trees_to_keep->reserve(num_trees - trees_to_drop.size()); - - int32 index = 0; - // This assumes that trees_to_drop is a sorted list of tree ids. - for (int32 tree = 0; tree < num_trees; ++tree) { - if ((!trees_to_drop.empty() && index < trees_to_drop.size() && - trees_to_drop[index] == tree) || - (only_finalized && config.tree_metadata_size() > 0 && - !config.tree_metadata(tree).is_finalized())) { - ++index; - continue; - } - trees_to_keep->push_back(tree); - } -} - -void UpdatePredictions( - const int32 index_1, const int32 index_2, const float value, - tensorflow::TTypes::Matrix* output_predictions, - tensorflow::TTypes::Matrix* additional_output_predictions) { - (*output_predictions)(index_1, index_2) += value; - - if (additional_output_predictions != nullptr) { - (*additional_output_predictions)(index_1, index_2) += value; - } -} - -void UpdatePredictionsBasedOnTree( - const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const int32 tree_idx, const boosted_trees::utils::Example& example, - tensorflow::TTypes::Matrix* output_predictions, - tensorflow::TTypes::Matrix* additional_output_predictions) { - const boosted_trees::trees::DecisionTreeConfig& tree = config.trees(tree_idx); - const float tree_weight = config.tree_weights(tree_idx); - const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); - QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); - const auto& leaf_node = tree.nodes(leaf_idx); - QCHECK(leaf_node.has_leaf()) - << "Invalid leaf node: " << leaf_node.DebugString(); - if (leaf_node.leaf().has_sparse_vector()) { - const auto& leaf = leaf_node.leaf().sparse_vector(); - QCHECK_EQ(leaf.index_size(), leaf.value_size()); - for (size_t class_idx = 0; class_idx < leaf.index_size(); ++class_idx) { - const float value = tree_weight * leaf.value(class_idx); - - UpdatePredictions(example.example_idx, leaf.index(class_idx), value, - output_predictions, additional_output_predictions); - } - } else { - QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type"; - const auto& leaf = leaf_node.leaf().vector(); - for (size_t i = 0; i < leaf.value_size(); ++i) { - const float value = tree_weight * leaf.value(i); - UpdatePredictions(example.example_idx, i, value, output_predictions, - additional_output_predictions); - } - } -} - -} // namespace - void MultipleAdditiveTrees::Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const bool only_finalized_trees, const std::vector& trees_to_drop, + const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, - tensorflow::thread::ThreadPool* worker_threads, - tensorflow::TTypes::Matrix output_predictions, - tensorflow::TTypes::Matrix no_dropout_predictions) { + tensorflow::thread::ThreadPool* const worker_threads, + tensorflow::TTypes::Matrix output_predictions) { // Zero out predictions as the model is additive. output_predictions.setZero(); - no_dropout_predictions.setZero(); // Get batch size. const int64 batch_size = features.batch_size(); @@ -104,27 +36,37 @@ void MultipleAdditiveTrees::Predict( return; } - // Prepare the list of trees to keep. - std::vector trees_to_keep; - CalculateTreesToKeep(config, trees_to_drop, config.trees_size(), - only_finalized_trees, &trees_to_keep); - // Lambda for doing a block of work. - auto update_predictions = [&config, &features, &trees_to_keep, &trees_to_drop, - &output_predictions, - &no_dropout_predictions](int64 start, int64 end) { + auto update_predictions = [&config, &features, &trees_to_include, + &output_predictions](int64 start, int64 end) { auto examples_iterable = features.examples_iterable(start, end); for (const auto& example : examples_iterable) { - for (const int32 tree_idx : trees_to_keep) { - UpdatePredictionsBasedOnTree(config, tree_idx, example, - &output_predictions, - &no_dropout_predictions); - } - - // Now do predictions for dropped trees - for (const int32 tree_idx : trees_to_drop) { - UpdatePredictionsBasedOnTree(config, tree_idx, example, - &no_dropout_predictions, nullptr); + for (const int32 tree_idx : trees_to_include) { + const boosted_trees::trees::DecisionTreeConfig& tree = + config.trees(tree_idx); + const float tree_weight = config.tree_weights(tree_idx); + const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); + QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + const auto& leaf_node = tree.nodes(leaf_idx); + QCHECK(leaf_node.has_leaf()) + << "Invalid leaf node: " << leaf_node.DebugString(); + if (leaf_node.leaf().has_sparse_vector()) { + const auto& leaf = leaf_node.leaf().sparse_vector(); + QCHECK_EQ(leaf.index_size(), leaf.value_size()); + for (size_t logit_dim = 0; logit_dim < leaf.index_size(); + ++logit_dim) { + const float value = tree_weight * leaf.value(logit_dim); + output_predictions(example.example_idx, leaf.index(logit_dim)) += + value; + } + } else { + QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type"; + const auto& leaf = leaf_node.leaf().vector(); + for (size_t i = 0; i < leaf.value_size(); ++i) { + const float value = tree_weight * leaf.value(i); + output_predictions(example.example_idx, i) += value; + } + } } } }; diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index fedade2026137ce43ff6b1cecd21f1e6c1461960..ee29a8aa797b96d41ec2d77bf831ee287d5443e7 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -32,15 +32,13 @@ namespace models { class MultipleAdditiveTrees { public: // Predict runs tree ensemble on the given batch and updates - // output predictions accordingly. The method also returns predictions that - // we would get if no dropout was applied. + // output predictions accordingly, for the given list of trees. static void Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const bool only_finalized_trees, const std::vector& trees_to_drop, + const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, - thread::ThreadPool* const thread_pool, - TTypes::Matrix output_predictions, - TTypes::Matrix no_dropout_predictions); + tensorflow::thread::ThreadPool* const worker_threads, + tensorflow::TTypes::Matrix output_predictions); }; } // namespace models diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc index 5f0924b48f2a57c5ba8af1e564e344e8ffa1b676..4ca18bedb1054ef64c6d4b25bbad04842bab1a6a 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -57,22 +57,14 @@ TEST_F(MultipleAdditiveTreesTest, Empty) { DecisionTreeEnsembleConfig tree_ensemble_config; auto output_tensor = AsTensor({9.0f, 23.0f}, {2, 1}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_matrix = output_tensor.matrix(); // Predict for both instances. tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, + &threads, output_matrix); EXPECT_EQ(0, output_matrix(0, 0)); EXPECT_EQ(0, output_matrix(1, 0)); - - // There was no dropout - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); - } } TEST_F(MultipleAdditiveTreesTest, SingleClass) { @@ -101,89 +93,48 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { auto output_tensor = AsTensor({0.0f, 0.0f}, {2, 1}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_tensor = AsTensor({0.0f, 0.0f}, {2, 1}); - auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); - tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); // Normal case. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); - } } // Weighted case { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); - MultipleAdditiveTrees::Predict(weighted, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, + output_matrix); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); // -0.4 (bias) + 0.9 (leaf 1). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9 * 3.2, output_matrix(1, 0)); - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); - } } // Drop first tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). - - // No dropout predictions - EXPECT_FLOAT_EQ( - -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } // Drop second tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias). EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). - - // No dropout predictions - EXPECT_FLOAT_EQ( - -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } // Drop all trees. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0, 1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); - - // No dropout predictions - EXPECT_FLOAT_EQ( - -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } } @@ -218,37 +169,22 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { auto output_tensor = AsTensor({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_tensor = - AsTensor({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2}); - auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); - // Normal case. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias) - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - for (int j = 0; j < 2; ++j) { - EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j)); - } - } } // Weighted case. { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); - MultipleAdditiveTrees::Predict(weighted, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, + output_matrix); // bias EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); // bias + leaf 2 @@ -260,60 +196,30 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { } // Dropout first tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1)); - - // No dropout predictions - EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) - EXPECT_FLOAT_EQ( - -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) - EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) } // Dropout second tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias) - - // No dropout predictions - EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) - EXPECT_FLOAT_EQ( - -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) - EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) } // Drop both trees. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0, 1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1)); - - // No dropout predictions - EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) - EXPECT_FLOAT_EQ( - -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) - EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) } } @@ -349,29 +255,16 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { AsTensor({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_tensor = - AsTensor({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3}); - auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); - // Normal case. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (tree1) + 0.9 (leaf 1) EXPECT_FLOAT_EQ(0.1f, output_matrix(1, 1)); // -0.7 (tree1) + 0.8 (leaf 1) EXPECT_FLOAT_EQ(3.7f, output_matrix(1, 2)); // 3.0 (tree1) + 0.7 (leaf 1) - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - for (int j = 0; j < 3; ++j) { - EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j)); - } - } } } diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index dad3b4e10deff7b8fb3a2a393e27a5d7099984a1..c329c6d4f7363a7738b06648943fe1dbd065cce5 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -36,7 +36,7 @@ class WeightedQuantilesSummary { struct SummaryEntry { SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min, const WeightType& max) { - // Explicitely initialize all of memory (including padding from memory + // Explicitly initialize all of memory (including padding from memory // alignment) to allow the struct to be msan-resistant "plain old data". // // POD = http://en.cppreference.com/w/cpp/concept/PODType diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index 31635906240d582f8ebbb9c8d14f1b2431409bc3..82b8e8c1c272ca415b5841f5ba9433e00173f8fa 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -36,10 +36,7 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, reduce_dim ? learner_config.num_classes() - 1 : learner_config.num_classes())}); - c->set_output(1, {c->Matrix(InferenceContext::kUnknownDim, - reduce_dim ? learner_config.num_classes() - 1 - : learner_config.num_classes())}); - c->set_output(2, {c->Vector(InferenceContext::kUnknownDim)}); + c->set_output(1, {c->Vector(InferenceContext::kUnknownDim)}); return Status::OK(); } @@ -63,7 +60,6 @@ REGISTER_OP("GradientTreesPrediction") .Input("sparse_int_feature_values: num_sparse_int_features * int64") .Input("sparse_int_feature_shapes: num_sparse_int_features * int64") .Output("predictions: float") - .Output("no_dropout_predictions: float") .Output("drop_out_tree_indices_weights: float") .SetShapeFn(ApplyGradientTreesPredictionShapeFn) .Doc(R"doc( @@ -90,8 +86,6 @@ sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices. sparse_int_feature_values: Rank 1 Tensors containing sparse int values. sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes. predictions: Rank 2 Tensor containing predictions per example per class. -no_dropout_predictions: The same as predictions, but using all trees (even -those that were dropped due to dropout). drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. )doc"); diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 1ee3d71c5abda690060e245880cd5a4e764fd21c..27c288bbf78b3b593d0807e92ac7fd9afc4d2725 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -114,7 +114,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): name="create_tree") resources.initialize_resources(resources.shared_resources()).run() - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -175,7 +175,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 3 - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle2, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -241,7 +241,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): stamp_token=3, tree_ensemble_config=tree_ensemble_config.SerializeToString()) ]): - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -270,7 +270,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): stamp_token=3, tree_ensemble_config=tree_ensemble_config.SerializeToString()) ]): - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -293,7 +293,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): stamp_token=0, tree_ensemble_config="", name="restore_tree") my_saver = saver.Saver() my_saver.restore(sess, save_path) - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 37595f1c75deab4db810d6ae49b57f56f417c52f..cf0958511350f82d548c56849f6179ae0f0215f5 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -151,22 +151,20 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) self.assertAllEqual([[0], [0]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -189,22 +187,20 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) self.assertAllClose([[-0.4], [-0.4]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -230,22 +226,20 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 3 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) self.assertAllClose([[-0.4, 0.9], [-0.4, 0.9]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -285,27 +279,25 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # The first example will get bias -0.4 from first tree and # leaf 4 payload of -0.9 hence -1.3, the second example will # get the same bias -0.4 and leaf 3 payload (sparse feature missing) # of 1.2 hence 0.8. self.assertAllClose([[-1.3], [0.8]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -346,25 +338,23 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # All the examples should get only the bias since the second tree is # non-finalized self.assertAllClose([[-0.4], [-0.4]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -405,27 +395,25 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # The first example will get bias -0.4 from first tree and # leaf 4 payload of -0.9 hence -1.3, the second example will # get the same bias -0.4 and leaf 3 payload (sparse feature missing) # of 1.2 hence 0.8. Note that the non-finalized tree is included. self.assertAllClose([[-1.3], [0.8]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -466,27 +454,25 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # The first example will get bias -0.4 from first tree and # leaf 4 payload of -0.9 hence -1.3, the second example will # get the same bias -0.4 and leaf 3 payload (sparse feature missing) # of 1.2 hence 0.8. self.assertAllClose([[-1.3], [0.8]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -526,26 +512,24 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.TREE_PER_CLASS) - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], # the second example will get the same bias class 1 -0.2 and leaf 3 # payload of class 1 1.2 hence [0.0, 1.0]. self.assertAllClose([[0.5, -0.2], [0, 1.0]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -588,26 +572,24 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=False)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=False) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], # the second example will get the same bias class 1 -0.2 and leaf 3 # payload of class 1 1.2 and class 2-0.7 hence [0.0, 1.0, -0.7]. self.assertAllClose([[0.5, -0.2, 0.0], [0, 1.0, -0.7]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -649,26 +631,24 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=False)) + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=False) # The first example will get bias class 1 -0.2 and -2 for class 2 from # first tree and leaf 2 payload (sparse feature missing) of 0.5 hence # 0.5, -0.2], the second example will get the same bias and leaf 3 payload # of class 1 1.2 and class 2-0.7 hence [0.0, 1.0, -2.7]. self.assertAllClose([[0.5, -0.2, -2.0], [0, 1.0, -2.7]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -697,7 +677,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): with self.test_session(): # Empty tree ensenble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - # Add 10 trees with some weights. + # Add 1000 trees with some weights. for i in range(0, 999): tree = tree_ensemble_config.trees.add() tree_ensemble_config.tree_metadata.add().is_finalized = True @@ -717,7 +697,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="existing") resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config, apply_dropout=True, @@ -729,10 +709,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertIn(dropout_info[0].size, range(400, 601)) self.assertEqual(dropout_info[0].size, dropout_info[1].size) - self.assertEqual(result.eval().size, result_no_dropout.eval().size) - for i in range(result.eval().size): - self.assertNotEqual(result.eval()[i], result_no_dropout.eval()[i]) - for i in range(dropout_info[0].size): dropped_index = dropout_info[0][i] dropped_weight = dropout_info[1][i] @@ -741,17 +717,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertEqual(dropped_index + 1, dropped_weight) # Don't apply dropout. - result, result_no_dropout, dropout_info = self._get_predictions( + result_no_dropout, no_dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config, apply_dropout=False, apply_averaging=False, center_bias=False) - # We expect none of the trees were dropped. - self.assertAllEqual([[], []], dropout_info.eval()) + self.assertEqual(result.eval().size, result_no_dropout.eval().size) + for i in range(result.eval().size): + self.assertNotEqual(result.eval()[i], result_no_dropout.eval()[i]) - self.assertAllEqual(result.eval(), result_no_dropout.eval()) + # We expect none of the trees were dropped. + self.assertAllEqual([[], []], no_dropout_info.eval()) def testDropoutCenterBiasNoGrowingMeta(self): # This is for normal non-batch mode where ensemble does not contain the tree @@ -780,20 +758,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="existing") resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config, apply_dropout=True, apply_averaging=False, center_bias=False) - result_center, result_no_dropout_center, dropout_info_center = ( - self._get_predictions( - tree_ensemble_handle, - learner_config=learner_config, - apply_dropout=True, - apply_averaging=False, - center_bias=True)) + result_center, dropout_info_center = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config, + apply_dropout=True, + apply_averaging=False, + center_bias=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -820,9 +797,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertEqual(num_trees - 1, dropout_info_center[0][num_dropped_center - 1]) - self.assertAllEqual(result_no_dropout.eval(), - result_no_dropout_center.eval()) - def testDropoutCenterBiasWithGrowingMeta(self): # This is batch mode where ensemble already contains the tree that we are # building. This tree should never be dropped. @@ -854,20 +828,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="existing") resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config, apply_dropout=True, apply_averaging=False, center_bias=False) - result_center, result_no_dropout_center, dropout_info_center = ( - self._get_predictions( - tree_ensemble_handle, - learner_config=learner_config, - apply_dropout=True, - apply_averaging=False, - center_bias=True)) + result_center, dropout_info_center = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config, + apply_dropout=True, + apply_averaging=False, + center_bias=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -893,9 +866,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertNotEqual(num_trees - 1, dropout_info_center[0][num_dropped_center - 1]) - self.assertAllEqual(result_no_dropout.eval(), - result_no_dropout_center.eval()) - def testDropoutSeed(self): with self.test_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() @@ -918,67 +888,63 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="empty") resources.initialize_resources(resources.shared_resources()).run() - _, result_no_dropout_1, dropout_info_1 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) - - _, result_no_dropout_2, dropout_info_2 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + _, dropout_info_1 = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=False, + reduce_dim=True) + + _, dropout_info_2 = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # Different seed. - _, result_no_dropout_3, dropout_info_3 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - 112314, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + _, dropout_info_3 = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + 112314, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # First seed with centering bias. - _, result_no_dropout_4, dropout_info_4 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=True, - reduce_dim=True)) + _, dropout_info_4 = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2 + ], [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, + self._sparse_float_shape2], [self._sparse_int_indices1], + [self._sparse_int_values1], [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=True, + reduce_dim=True) # The same seed returns the same results. self.assertAllEqual(dropout_info_1.eval(), dropout_info_2.eval()) @@ -991,31 +957,46 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertEqual( len(dropout_info_4.eval()[0]) + 1, len(dropout_info_1.eval()[0])) - # Predictions without dropout are all the same. - result, result_no_dropout, _ = prediction_ops.gradient_trees_prediction( + def testDropOutZeroProb(self): + with self.test_session(): + # Empty tree ensenble. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + # Add 1000 trees with some weights. + for i in range(0, 999): + tree = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + _append_to_leaf(tree.nodes.add().leaf, 0, -0.4) + tree_ensemble_config.tree_weights.append(i + 1) + + # Dropout with 0 probability. + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.dropout.dropout_probability = 0.0 + learner_config.learning_rate_tuner.dropout.learning_rate = 1.0 + learner_config.num_classes = 2 + + # Apply dropout, but expect nothing dropped. + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="existing") + resources.initialize_resources(resources.shared_resources()).run() + + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), + learner_config=learner_config, + apply_dropout=True, + apply_averaging=False, + center_bias=False) + + result_no_dropout, _ = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config, apply_dropout=False, apply_averaging=False, - center_bias=False, - reduce_dim=True) + center_bias=False) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_1.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_2.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_3.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_4.eval()) + self.assertAllEqual([[], []], dropout_info.eval()) + self.assertAllClose(result.eval(), result_no_dropout.eval()) def testAveragingAllTrees(self): with self.test_session(): @@ -1066,17 +1047,14 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() # Do averaging. - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config, apply_averaging=True) - pattern_result, pattern_result_no_dropout, pattern_dropout_info = ( - self._get_predictions( - adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + pattern_result, pattern_dropout_info = (self._get_predictions( + adjusted_tree_ensemble_handle, + learner_config_no_averaging, + apply_averaging=False)) - self.assertAllEqual(result_no_dropout.eval(), - pattern_result_no_dropout.eval()) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) @@ -1137,22 +1115,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() - result_1, result_no_dropout_1, dropout_info_1 = self._get_predictions( + result_1, dropout_info_1 = self._get_predictions( tree_ensemble_handle, learner_config_1, apply_averaging=True) - result_2, result_no_dropout_2, dropout_info_2 = self._get_predictions( + result_2, dropout_info_2 = self._get_predictions( tree_ensemble_handle, learner_config_2, apply_averaging=True) - pattern_result, pattern_result_no_dropout, pattern_dropout_info = ( - self._get_predictions( - adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) - - self.assertAllEqual(result_no_dropout_1.eval(), - pattern_result_no_dropout.eval()) - self.assertAllEqual(result_no_dropout_2.eval(), - pattern_result_no_dropout.eval()) + pattern_result, pattern_dropout_info = self._get_predictions( + adjusted_tree_ensemble_handle, + learner_config_no_averaging, + apply_averaging=False) self.assertAllEqual(result_1.eval(), pattern_result.eval()) self.assertAllEqual(result_2.eval(), pattern_result.eval()) @@ -1206,17 +1178,14 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config, apply_averaging=True) - pattern_result, pattern_result_no_dropout, pattern_dropout_info = ( - self._get_predictions( - adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + pattern_result, pattern_dropout_info = (self._get_predictions( + adjusted_tree_ensemble_handle, + learner_config_no_averaging, + apply_averaging=False)) - self.assertAllEqual(result_no_dropout.eval(), - pattern_result_no_dropout.eval()) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 2d28e0a9f160373b4565d83e9b57de401a052bd6..f8f4b43a072a91f1563b20d6ba3aef82fd4b9896 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -56,7 +56,6 @@ PREDICTIONS = "predictions" PARTITION_IDS = "partition_ids" NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" -PREDICTIONS_NO_DROPOUT = "predictions_no_dropout" _FEATURE_NAME_TEMPLATE = "%s_%d" @@ -70,15 +69,13 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, logits_no_dropout, partition_ids, - ensemble_stats): +def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): """Returns predictions for the given logits and n_classes. Args: stamp: The ensemble stamp. logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. - logits_no_dropout: A rank 2 `Tensor` with shape [batch_size, n_classes - 1] - that contains predictions when no dropout was applied. + that contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. @@ -88,9 +85,7 @@ def _make_predictions_dict(stamp, logits, logits_no_dropout, partition_ids, result = {} result[ENSEMBLE_STAMP] = stamp result[PREDICTIONS] = logits - result[PREDICTIONS_NO_DROPOUT] = logits_no_dropout result[PARTITION_IDS] = partition_ids - result[NUM_LAYERS_ATTEMPTED] = ensemble_stats.attempted_layers result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees return result @@ -348,6 +343,57 @@ class GradientBoostedDecisionTreeModel(object): learner_pb2.LearnerConfig.TREE_PER_CLASS and learner_config.num_classes == 2) + def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): + """Runs prediciton and returns a dictionary of the prediction results. + + Args: + ensemble_handle: ensemble resource handle. + ensemble_stamp: stamp of ensemble resource. + mode: learn.ModeKeys.TRAIN or EVAL or INFER. + + Returns: + a dictionary of prediction results - + ENSEMBLE_STAMP, PREDICTION, PARTITION_IDS, + NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPED. + """ + ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, + ensemble_stamp) + # We don't need dropout info - we can always restore it based on the + # seed. + apply_dropout, seed = _dropout_params(mode, ensemble_stats) + # Make sure ensemble stats run. This will check that the ensemble has + # the right stamp. + with ops.control_dependencies(ensemble_stats): + predictions, _ = prediction_ops.gradient_trees_prediction( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) + partition_ids = prediction_ops.gradient_trees_partition_examples( + ensemble_handle, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + use_locking=True) + + return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, + ensemble_stats) + def predict(self, mode): """Returns predictions given the features and mode. @@ -360,7 +406,6 @@ class GradientBoostedDecisionTreeModel(object): Raises: ValueError: if features is not valid. """ - apply_averaging = mode != learn.ModeKeys.TRAIN # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device @@ -409,83 +454,13 @@ class GradientBoostedDecisionTreeModel(object): # Once updated, use the local model for prediction. with ops.control_dependencies([refresh_local_ensemble]): - ensemble_stats = training_ops.tree_ensemble_stats( - local_ensemble_handle, ensemble_stamp) - # We don't need dropout info - we can always restore it based on the - # seed. - apply_dropout, seed = _dropout_params(mode, ensemble_stats) - # Make sure ensemble stats run. This will check that the ensemble has - # the right stamp. - with ops.control_dependencies(ensemble_stats): - predictions, predictions_no_dropout, _ = ( - prediction_ops.gradient_trees_prediction( - local_ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim)) - partition_ids = prediction_ops.gradient_trees_partition_examples( - local_ensemble_handle, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - use_locking=True) - + return self._predict_and_return_dict(local_ensemble_handle, + ensemble_stamp, mode) else: + # Use ensemble_handle directly, if colocated. with ops.device(self._ensemble_handle.device): - ensemble_stats = training_ops.tree_ensemble_stats( - self._ensemble_handle, ensemble_stamp) - # We don't need dropout info - we can always restore it based on the - # seed. - apply_dropout, seed = _dropout_params(mode, ensemble_stats) - # Make sure ensemble stats run. This will check that the ensemble has - # the right stamp. - with ops.control_dependencies(ensemble_stats): - predictions, predictions_no_dropout, _ = ( - prediction_ops.gradient_trees_prediction( - self._ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim)) - partition_ids = prediction_ops.gradient_trees_partition_examples( - self._ensemble_handle, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - use_locking=True) - - return _make_predictions_dict(ensemble_stamp, predictions, - predictions_no_dropout, partition_ids, - ensemble_stats) + return self._predict_and_return_dict(self._ensemble_handle, + ensemble_stamp, mode) def train(self, loss, predictions_dict, labels): """Grows a new tree and adds it to the ensemble. @@ -546,8 +521,8 @@ class GradientBoostedDecisionTreeModel(object): hessians = array_ops.stack(hessian_list, axis=1) # Choose the class for which the tree is built (one vs rest). - class_id = predictions_dict[NUM_TREES_ATTEMPTED] % num_classes - class_id = math_ops.to_int32(class_id) + class_id = math_ops.to_int32( + predictions_dict[NUM_TREES_ATTEMPTED] % num_classes) # Use class id tensor to get the column with that index from gradients # and hessians. @@ -711,7 +686,7 @@ class GradientBoostedDecisionTreeModel(object): handler_results = batch_ops_utils.run_handler_scheduled_ops( handler_reads, ensemble_stamp, worker_device) per_handler_updates = {} - # Two values per handler. First one is if the the handler is active for the + # Two values per handler. First one is if the handler is active for the # current layer. The second one is if the handler is going to be active # for the next layer. subsampling_type = self._learner_config.WhichOneof("feature_fraction") @@ -803,7 +778,10 @@ class GradientBoostedDecisionTreeModel(object): active_tree, active_layer, dropout_seed, class_id), control_flow_ops.no_op)) - # Calculate the loss to be reported - use the predictions without dropout. + # Calculate the loss to be reported. + # Note, the loss is calculated from the prediction considering dropouts, so + # that the value might look staggering over steps when the dropout ratio is + # high. eval_loss might be referred instead in the aspect of convergence. return control_flow_ops.group(*ensemble_update_ops) def _get_weights(self, hessian_shape, hessians): diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 45c3bbadfc8d6300841cbc256c894e3bb14cb44e..77e6ecb443dd3f0f7a96b7453f558d58f01c7a21 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -44,9 +44,84 @@ class DecisionTreeEnsembleResource : public StampedResource { return *decision_tree_ensemble_; } - boosted_trees::trees::DecisionTreeEnsembleConfig* - mutable_decision_tree_ensemble() { - return decision_tree_ensemble_; + int32 num_trees() const { return decision_tree_ensemble_->trees_size(); } + + bool InitFromSerialized(const string& serialized, const int64 stamp_token) { + if (ParseProtoUnlimited(decision_tree_ensemble_, serialized)) { + set_stamp(stamp_token); + return true; + } + return false; + } + + string SerializeAsString() const { + return decision_tree_ensemble_->SerializeAsString(); + } + + // Increment num_layers_attempted and num_trees_attempted in growing_metadata + // if the tree is finalized. + void IncrementAttempts() { + boosted_trees::trees::GrowingMetadata* const growing_metadata = + decision_tree_ensemble_->mutable_growing_metadata(); + growing_metadata->set_num_layers_attempted( + growing_metadata->num_layers_attempted() + 1); + const int num_trees = decision_tree_ensemble_->trees_size(); + if (num_trees <= 0 || LastTreeMetadata()->is_finalized()) { + growing_metadata->set_num_trees_attempted( + growing_metadata->num_trees_attempted() + 1); + } + } + + boosted_trees::trees::DecisionTreeConfig* AddNewTree(const float weight) { + // Adding a tree as well as a weight and a tree_metadata. + decision_tree_ensemble_->add_tree_weights(weight); + boosted_trees::trees::DecisionTreeMetadata* const metadata = + decision_tree_ensemble_->add_tree_metadata(); + metadata->set_num_layers_grown(1); + return decision_tree_ensemble_->add_trees(); + } + + void RemoveLastTree() { + QCHECK_GT(decision_tree_ensemble_->trees_size(), 0); + decision_tree_ensemble_->mutable_trees()->RemoveLast(); + decision_tree_ensemble_->mutable_tree_weights()->RemoveLast(); + decision_tree_ensemble_->mutable_tree_metadata()->RemoveLast(); + } + + boosted_trees::trees::DecisionTreeConfig* LastTree() { + const int32 tree_size = decision_tree_ensemble_->trees_size(); + QCHECK_GT(tree_size, 0); + return decision_tree_ensemble_->mutable_trees(tree_size - 1); + } + + boosted_trees::trees::DecisionTreeMetadata* LastTreeMetadata() { + const int32 metadata_size = decision_tree_ensemble_->tree_metadata_size(); + QCHECK_GT(metadata_size, 0); + return decision_tree_ensemble_->mutable_tree_metadata(metadata_size - 1); + } + + // Retrieves tree weights and returns as a vector. + std::vector GetTreeWeights() const { + return {decision_tree_ensemble_->tree_weights().begin(), + decision_tree_ensemble_->tree_weights().end()}; + } + + float GetTreeWeight(const int32 index) const { + return decision_tree_ensemble_->tree_weights(index); + } + + // Sets the weight of i'th tree, and increment num_updates in tree_metadata. + void SetTreeWeight(const int32 index, const float weight, + const int32 increment_num_updates) { + QCHECK_GE(index, 0); + QCHECK_LT(index, num_trees()); + decision_tree_ensemble_->set_tree_weights(index, weight); + if (increment_num_updates != 0) { + const int32 num_updates = decision_tree_ensemble_->tree_metadata(index) + .num_tree_weight_updates(); + decision_tree_ensemble_->mutable_tree_metadata(index) + ->set_num_tree_weight_updates(num_updates + increment_num_updates); + } } // Resets the resource and frees the protos in arena. @@ -64,7 +139,7 @@ class DecisionTreeEnsembleResource : public StampedResource { mutex* get_mutex() { return &mu_; } - private: + protected: protobuf::Arena arena_; mutex mu_; boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_; diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 09ec7e42c7eede97b9c7eeee329fe0649365869e..56f930a9a8d32c5c3a025163ef56c9562f17d864 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -23,7 +23,9 @@ load( filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", @@ -34,9 +36,7 @@ filegroup( tf_kernel_library( name = "bigquery_reader_ops", - srcs = [ - "bigquery_reader_ops.cc", - ], + srcs = ["bigquery_reader_ops.cc"], visibility = ["//visibility:public"], deps = [ ":bigquery_table_accessor", @@ -50,12 +50,8 @@ tf_kernel_library( cc_library( name = "bigquery_table_accessor", - srcs = [ - "bigquery_table_accessor.cc", - ], - hdrs = [ - "bigquery_table_accessor.h", - ], + srcs = ["bigquery_table_accessor.cc"], + hdrs = ["bigquery_table_accessor.h"], copts = tf_copts(), linkstatic = 1, deps = [ @@ -64,7 +60,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform/cloud:curl_http_request", "//tensorflow/core/platform/cloud:google_auth_provider", - "//tensorflow/core/platform/cloud:http_request", ], alwayslink = 1, ) @@ -88,8 +83,6 @@ tf_cc_test( tf_proto_library( name = "bigquery_table_partition_proto", - srcs = [ - "bigquery_table_partition.proto", - ], + srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index ceb583abe0796ec9748e752f112ce9e368bdd8c0..d76ddf8c657b9b5d02bbdc4d6759053396dcd6d2 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -39,7 +39,6 @@ class TPUClusterResolver(ClusterResolver): """ def __init__(self, - api_definition, project, zone, tpu_names, @@ -52,8 +51,6 @@ class TPUClusterResolver(ClusterResolver): for the IP addresses and ports of each Cloud TPU listed. Args: - api_definition: (Alpha only) A copy of the JSON API definitions for - Cloud TPUs. This will be removed once Cloud TPU enters beta. project: Name of the GCP project containing Cloud TPUs zone: Zone where the TPUs are located tpu_names: A list of names of the target Cloud TPUs. @@ -83,11 +80,13 @@ class TPUClusterResolver(ClusterResolver): raise ImportError('googleapiclient must be installed before using the ' 'TPU cluster resolver') - # TODO(frankchn): Remove once Cloud TPU API Definitions are public and - # replace with discovery.build('tpu', 'v1') - self._service = discovery.build_from_document( - api_definition, - credentials=self._credentials) + # TODO(b/67375680): Remove custom URL once TPU APIs are finalized + self._service = discovery.build( + 'tpu', + 'v1', + credentials=self._credentials, + discoveryServiceUrl='https://storage.googleapis.com' + '/tpu-api-definition/v1alpha1.json') else: self._service = service diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index c249a2855622581534534a94af9991d12b73f5e9..8744fc492ff67064bff2097c99be5af8a739b60d 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -245,7 +245,7 @@ if (tensorflow_ENABLE_GPU) "#define CUDA_CUDA_CONFIG_H_\n" "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" "#define TF_CUDA_VERSION \"64_80\"\n" - "#define TF_CUDNN_VERSION \"64_5\"\n" + "#define TF_CUDNN_VERSION \"64_6\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -264,8 +264,23 @@ if (tensorflow_ENABLE_GPU) include_directories(${tensorflow_source_dir}/third_party/gpus) # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CUDA_LIBRARIES}) - endif() -endif() + + # NOTE(mrry): Update these flags when the version of CUDA or cuDNN used + # in the default build is upgraded. + set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value + msvcp_dll_name=msvcp140.dll + cudart_dll_name=cudart64_80.dll + cuda_version_number=8.0 + nvcuda_dll_name=nvcuda.dll + cudnn_dll_name=cudnn64_6.dll + cudnn_version_number=6) + else(WIN32) + message(FATAL_ERROR "CMake GPU build is currently only supported on Windows.") + endif(WIN32) +else(tensorflow_ENABLE_GPU) + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value + msvcp_dll_name=msvcp140.dll) +endif(tensorflow_ENABLE_GPU) # Find python executable include(FindPythonInterp) diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index d98579d2077f0a3bc58e6466ee830e53f44f40cb..7b263806d733f0e1deafe3e8fdd9baf2bb6fd81f 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip) -set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe) +set(cub_URL https://github.com/NVlabs/cub/archive/1.7.4.zip) +set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c5a101812710f0e6eb0aa8816acd2b395e7f7472..f3882e8cf76c6dad31371fc340de959c05411a2f 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,6 +21,8 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc" + "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 61c6686ee0e50a69c173ef47dd1e72d9bb5a982f..65565aad7ea5469926f320839455cd884a343713 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -33,6 +33,8 @@ else(tensorflow_BUILD_ALL_KERNELS) "${tensorflow_source_dir}/tensorflow/core/kernels/matmul_op.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/no_op.h" "${tensorflow_source_dir}/tensorflow/core/kernels/no_op.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/ops_util.h" + "${tensorflow_source_dir}/tensorflow/core/kernels/ops_util.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/sendrecv_ops.h" "${tensorflow_source_dir}/tensorflow/core/kernels/sendrecv_ops.cc" ) @@ -65,6 +67,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" @@ -74,6 +78,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/bipartite_match_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" @@ -167,6 +178,7 @@ endif(WIN32) file(GLOB_RECURSE tf_core_gpu_kernels_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op_gpu.cu.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc" ) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 78bccc08a36b3be922867b410c75cda87e5d0983..dc9973917e48e77a0ffe04a687cb205e6342f46a 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -84,6 +84,8 @@ GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(image_distort_image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 3430439d4d0058905f6aa4e57af35a9a7d9909c5..e83618a94ecea28a46bab0ab7b3d8e2517102823 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -266,12 +266,14 @@ add_python_module("tensorflow/python/keras/_impl/keras/utils") add_python_module("tensorflow/python/keras/_impl/keras/wrappers") add_python_module("tensorflow/python/kernel_tests") add_python_module("tensorflow/python/kernel_tests/distributions") +add_python_module("tensorflow/python/kernel_tests/linalg") add_python_module("tensorflow/python/layers") add_python_module("tensorflow/python/lib") add_python_module("tensorflow/python/lib/core") add_python_module("tensorflow/python/lib/io") add_python_module("tensorflow/python/ops") add_python_module("tensorflow/python/ops/distributions") +add_python_module("tensorflow/python/ops/linalg") add_python_module("tensorflow/python/ops/losses") add_python_module("tensorflow/python/platform") add_python_module("tensorflow/python/platform/default") @@ -370,6 +372,8 @@ add_python_module("tensorflow/contrib/gan/python/eval") add_python_module("tensorflow/contrib/gan/python/eval/python") add_python_module("tensorflow/contrib/gan/python/features") add_python_module("tensorflow/contrib/gan/python/features/python") +add_python_module("tensorflow/contrib/gan/python/estimator") +add_python_module("tensorflow/contrib/gan/python/estimator/python") add_python_module("tensorflow/contrib/gan/python/losses") add_python_module("tensorflow/contrib/gan/python/losses/python") add_python_module("tensorflow/contrib/graph_editor") @@ -495,6 +499,7 @@ add_python_module("tensorflow/contrib/lookup") add_python_module("tensorflow/contrib/losses") add_python_module("tensorflow/contrib/losses/python") add_python_module("tensorflow/contrib/losses/python/losses") +add_python_module("tensorflow/contrib/losses/python/metric_learning") add_python_module("tensorflow/contrib/makefile") add_python_module("tensorflow/contrib/makefile/test") add_python_module("tensorflow/contrib/memory_stats") @@ -536,6 +541,8 @@ add_python_module("tensorflow/contrib/pi_examples/label_image/data") add_python_module("tensorflow/contrib/predictor") add_python_module("tensorflow/contrib/quantization") add_python_module("tensorflow/contrib/quantization/python") +add_python_module("tensorflow/contrib/quantize") +add_python_module("tensorflow/contrib/quantize/python") add_python_module("tensorflow/contrib/remote_fused_graph/pylib") add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") @@ -636,13 +643,8 @@ add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") -if(tensorflow_ENABLE_GPU) - set(BUILD_CONFIG_STRING "cuda") -else(tensorflow_ENABLE_GPU) - set(BUILD_CONFIG_STRING "cpu") -endif(tensorflow_ENABLE_GPU) add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD - COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/build_info/gen_build_info.py --build_config ${BUILD_CONFIG_STRING} --raw_generate ${BUILD_INFO_PY}) + COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/build_info/gen_build_info.py --raw_generate ${BUILD_INFO_PY} ${tensorflow_BUILD_INFO_FLAGS}) ######################################################## @@ -774,6 +776,10 @@ GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/input_pipeline/ops/gen_input_pipeline_ops.py) GENERATE_PYTHON_OP_LIB("contrib_image_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_image_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_image_distort_image_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_distort_image_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_image_sirds_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_single_image_random_dot_stereograms_ops.py) GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" @@ -840,6 +846,7 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc" "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h" + "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc" diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 9385ac52e903e1f0f2436066f573af5359c46770..9bf45bab3041142206900bf96beeddefb3308ee4 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -27,6 +27,7 @@ if(WIN32) $ $ $ + $ $ $ $ @@ -63,6 +64,7 @@ add_library(tensorflow SHARED $ $ $ + $ $ $ $ diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index ba78e87ac04d365c4c28273768111ba1fb6e783d..24f21afdfcdcc6cf65f288365eb16011d3db0ee4 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -152,6 +152,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" @@ -178,6 +179,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # exclude the ones we don't want set(tf_test_src_py_exclude + # generally excluded + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" + # Python source line inspection tests are flaky on Windows (b/36375074). "${tensorflow_source_dir}/tensorflow/python/debug/cli/analyzer_cli_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py" @@ -187,18 +191,12 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" # generally not working - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/resource_variable_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/pprof_profiler_test.py" # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" - # requires scipy - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" # flaky tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" # takes very long to run "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py" # Loading resources in contrib doesn't seem to work on Windows "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py" @@ -212,44 +210,48 @@ if (tensorflow_BUILD_PYTHON_TESTS) if (WIN32) set(tf_test_src_py_exclude ${tf_test_src_py_exclude} - # generally excluded - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" - # TODO: failing tests. # Nothing critical in here but should get this list down to [] # The failing list is grouped by failure source + # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cast_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/clip_ops_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/tensor_array_ops_test.py" # Needs portpicker. - # Matrix_set_diag failing on GPU on windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cholesky_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_ops_test.py" - # misc - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reshape_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py" # Depends on gemmlowp -> pthread. + # Numerical issues, calculations off. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + # Float division by zero + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" + # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" + # IteratorGetMax OutOfRangeError + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py" + # Depends on gemmlowp -> pthread + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py" # int32/int64 mixup + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py" + # Windows file management related issues. + "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # training tests "${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix. - "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows. "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename. - "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks. - - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_multi_gpu_test.py" # Fails on multiple GPUs. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" # numerical issues + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Dataset tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" # Broken tensorboard test due to cmake issues. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) @@ -258,8 +260,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py" # Bad placement. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/topn_test.py" # Results inaccurate "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support - # Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. # Dask.Dataframe bugs on Window Build "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py" @@ -268,39 +268,17 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Need extra build "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py" "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/depthtospace_op_test.py" # QuantizeV2 + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/spacetodepth_op_test.py" # QuantizeV2 # Windows Path "${tensorflow_source_dir}/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py" #TODO: Fix path - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/models_test.py" - # Related to Windows Multiprocessing https://github.com/fchollet/keras/issues/5071 - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/engine/training_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/callbacks_test.py" - # Scipy needed - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py" - # Failing with TF 1.3 (TODO) - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py" + # Numpy upgrade needed? "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py" # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 94aff13a49f5380d5804e190b33613fd42dcaebc..2108e42bce4eba1eed158fe85888f1699a69ba7e 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -173,12 +173,12 @@ class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): with self.test_session(): - x = constant_op.constant(3) - y_nc = math_ops.add(x, x, name="not_compiled") + x = constant_op.constant([[3]]) + y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): - y_c = math_ops.add(y_nc, y_nc, name="compiled") + y_c = math_ops.matmul(y_nc, y_nc, name="compiled") x_grads = gradients.gradients([y_c], [x])[0] - operations = x_grads.graph.get_operations() + operations = x.graph.get_operations() c_grad_ops = [ op for op in operations if "gradients/compiled" in op.name] nc_grad_ops = [ @@ -191,19 +191,19 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "No attr named"): ncg.get_attr("_XlaCompile") - # d/dx (4 * x) - self.assertAllClose(4, x_grads.eval()) + # d/dx (x ** 4) = 4 * (x ** 3) + self.assertAllClose([[108]], x_grads.eval()) def testCompilationGradientScopeNames(self): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) @@ -220,12 +220,12 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index d4214587cd1a0fa684710d37083028f9af0425d9..ae9413fdd63cafed306b10a0f68f3fc0315a22c3 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -54,7 +54,7 @@ tf_gen_op_wrapper_py( ) tf_custom_op_py_library( - name = "cudnn_rnn_py", + name = "cudnn_rnn_ops_py", srcs = [ "__init__.py", "python/ops/cudnn_rnn_ops.py", @@ -81,10 +81,67 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "cudnn_rnn_py", + srcs = [ + "__init__.py", + "python/layers/cudnn_rnn.py", + ], + dso = [ + ":python/ops/_cudnn_rnn_ops.so", + ], + kernels = [ + ":cudnn_rnn_kernels", + ":cudnn_rnn_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":cudnn_rnn_ops", + ":cudnn_rnn_ops_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + cuda_py_test( name = "cudnn_rnn_ops_test", size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], + additional_deps = [ + ":cudnn_rnn_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python/ops/losses:losses", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + shard_count = 6, + tags = [ + "manual", + "requires_cudnn5", + ], +) + +cuda_py_test( + name = "cudnn_rnn_test", + size = "large", + srcs = ["python/kernel_tests/cudnn_rnn_test.py"], additional_deps = [ ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", @@ -114,7 +171,7 @@ cuda_py_test( size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_py", + ":cudnn_rnn_ops_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce8954bb09d7444a552d0ba6b3d9bb72cd919fd --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -0,0 +1,1050 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Cudnn RNN models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import os +import unittest + +import numpy as np + +from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gradients_impl as gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn as rnn_lib +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import saver as saver_lib + +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION + +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + + +class CudnnTestModel(object): + """Model with convenient APIs for easier building and running test graph. + + The graph built is used by all tests below to avoid repeatedly building + similar test graphs. + """ + + def __init__(self, + rnn_mode, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + dtype=dtypes.float32, + training=False, + kernel_initializer=None, + bias_initializer=None): + if dtype not in (dtypes.float32, dtypes.float64): + raise ValueError("Invalid dtype: %s" % dtype) + self._dtype = dtype + + self._inputs = array_ops.placeholder( + dtype=dtype, shape=[None, None, input_size], name="inputs") + h = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="h") + c = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="c") + if rnn_mode == CUDNN_LSTM: + model_fn = cudnn_rnn.CudnnLSTM + self._initial_state = (h, c) + elif rnn_mode == CUDNN_GRU: + model_fn = cudnn_rnn.CudnnGRU + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_TANH: + model_fn = cudnn_rnn.CudnnRNNTanh + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_RELU: + model_fn = cudnn_rnn.CudnnRNNRelu + self._initial_state = (h,) + else: + raise ValueError("Invalid rnn_mode: %s" % rnn_mode) + self._rnn = model_fn( + num_layers, + num_units, + direction=direction, + dropout=dropout, + dtype=dtype, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + self._rnn.build([None, None, input_size]) + + self._outputs, self._output_state = self._rnn( + self._inputs, initial_state=self._initial_state, training=training) + + def _AddUp(self, outputs, output_state): + total = math_ops.reduce_sum(outputs) + for s in output_state: + total += math_ops.reduce_sum(s) + return total + + @property + def inputs(self): + return self._inputs + + @property + def initial_state(self): + return self._initial_state + + @property + def outputs(self): + return self._outputs + + @property + def output_state(self): + return self._output_state + + @property + def rnn(self): + return self._rnn + + @property + def total_sum(self): + return self._AddUp(self.outputs, self.output_state) + + def SynthesizeInput(self, seq_length, batch_size, seed=1234): + """Synthesizes input and initial state values for testing.""" + np.random.seed(seed) + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + input_size = self._rnn.input_size + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + inputs = np.random.randn(seq_length, batch_size, + input_size).astype(np_dtype) + input_h = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return inputs, initial_state + + def ZeroState(self, batch_size): + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + input_h = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return initial_state + + def FProp(self, inputs_t, initial_state_t, training): + """Builds additional subgraph with given inputs and state. + + Args: + inputs_t: a tensor. + initial_state_t: a tensor. + training: boolean, true if training mode. + Returns: + A tensor of the forward pass output of the model. + """ + outputs, output_state = self._rnn( + inputs_t, initial_state=initial_state_t, training=training) + return self._AddUp(outputs, output_state) + + def Feed(self, sess, inputs, initial_state=None, return_sum=True): + """Runs graph with given inputs and initial state.""" + batch_size = inputs.shape[1] + if initial_state is None: + initial_state = self.ZeroState(batch_size) + if return_sum: + return sess.run( + self.total_sum, + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + else: + return sess.run( + [self.outputs, self.output_state], + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + + +def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None): + mode = rnn.rnn_mode + num_units = rnn.num_units + num_layers = rnn.num_layers + + # To reuse cuDNN-trained models, must use cudnn compatible rnn cells. + if mode == CUDNN_LSTM: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units) + elif mode == CUDNN_GRU: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units) + elif mode == CUDNN_RNN_TANH: + single_cell = (lambda: rnn_cell_impl.BasicRNNCell(num_units, math_ops.tanh)) + elif mode == CUDNN_RNN_RELU: + single_cell = ( + lambda: rnn_cell_impl.BasicRNNCell(num_units, gen_nn_ops.relu)) + else: + raise ValueError("%s is not supported!" % mode) + + if not is_bidi: + cell = rnn_cell_impl.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + return rnn_lib.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) + else: + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + (outputs, output_state_fw, + output_state_bw) = contrib_rnn_lib.stack_bidirectional_dynamic_rnn( + cells_fw, + cells_bw, + inputs, + dtype=dtypes.float32, + time_major=True, + scope=scope) + return outputs, (output_state_fw, output_state_bw) + + +class CudnnRNNTestBasic(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testLayerBasic(self): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + with vs.variable_scope("main"): + kernel_initializer = init_ops.constant_initializer(0.) + bias_initializer = init_ops.constant_initializer(0.) + inputs = random_ops.random_uniform([ + num_layers * dir_count, batch_size, num_units], dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Build the layer + outputs1, _ = lstm(inputs) + # Reuse the layer + outputs2, _ = lstm(inputs) + + total_sum1 = math_ops.reduce_sum(outputs1) + total_sum2 = math_ops.reduce_sum(outputs2) + + with vs.variable_scope("main", reuse=True): + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Reuse the layer + outputs3, _ = lstm(inputs) + total_sum3 = math_ops.reduce_sum(outputs3) + + self.assertEqual(1, len(variables.trainable_variables())) + self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))) + self.assertEqual("main/awesome_lstm/opaque_kernel", + variables.trainable_variables()[0].op.name) + + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + (total_sum1_v, total_sum2_v, total_sum3_v) = sess.run( + [total_sum1, total_sum2, total_sum3]) + self.assertEqual(0, total_sum1_v) + self.assertEqual(0, total_sum2_v) + self.assertEqual(0, total_sum3_v) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. +class CudnnRNNTestSaveRestore(TensorFlowTestCase): + + def _CompareWeights(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): + self.assertEqual(len(lhs), len(rhs)) + if rnn_mode == CUDNN_LSTM: + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_GRU: + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + else: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 + num_params_per_layer *= num_dirs + self.assertEqual(num_params_per_layer * num_layers, len(lhs)) + + for i in range(num_layers): + layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + if direction == CUDNN_RNN_UNIDIRECTION: + self._CompareSingleLayerBiases(layer_lhs, layer_rhs) + else: + size = len(layer_lhs) + fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] + fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] + self._CompareSingleLayerBiases(fw_lhs, fw_rhs) + self._CompareSingleLayerBiases(bw_lhs, bw_rhs) + + def _CompareSingleLayerBiases(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + + lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] + lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] + self.assertEqual(len(lf_lhs), len(rt_lhs)) + self.assertEqual(len(lf_rhs), len(rt_rhs)) + + sum_lhs, sum_rhs = [], [] + for lf, rt in zip(lf_lhs, rt_lhs): + sum_lhs.append(lf + rt) + for lf, rt in zip(lf_rhs, rt_rhs): + sum_rhs.append(lf + rt) + self.assertEqual(len(sum_lhs), len(sum_rhs)) + for lf, rt in zip(sum_lhs, sum_rhs): + self.assertAllEqual(lf, rt) + + def _TestSaveRestoreVariable(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + rnn = model.rnn + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test") + saver = saver_lib.Saver() + weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + opaque_params = rnn.trainable_variables[0] + # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save + # Cudnn vars in canonical format. + reset_op = state_ops.assign( + opaque_params, + array_ops.zeros(array_ops.shape(opaque_params), dtype=dtype)) + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + weights_v, biases_v = sess.run([weights, biases]) + + # Reset opaque param + sess.run(reset_op) + saver.restore(sess, save_path) + weights_v_restored, biases_v_restored = sess.run([weights, biases]) + + self._CompareWeights(weights_v, weights_v_restored) + self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + with vs.variable_scope("m1"): + model1 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + with vs.variable_scope("m2"): + model2 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + opaque_params = (model1.rnn.trainable_variables[0], + model2.rnn.trainable_variables[0]) + weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() + weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + reset_params = [ + state_ops.assign(params, + array_ops.zeros_like(params, dtype=dtype)) + for params in opaque_params + ] + reset_op = control_flow_ops.group(*reset_params) + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test2") + saver = saver_lib.Saver() + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + weights1_v, biases1_v = sess.run([weights1, biases1]) + weights2_v, biases2_v = sess.run([weights2, biases2]) + + sess.run(reset_op) + saver.restore(sess, save_path) + weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) + weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) + + self._CompareWeights(weights1_v, weights1_v_restored) + self._CompareWeights(weights2_v, weights2_v_restored) + self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, + direction) + self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreOutput(self, rnn_mode, direction, dtype): + with ops.Graph().as_default() as g: + num_layers = 2 + num_units = 7 + input_size = 7 + seq_length = 8 + batch_size = 4 + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + + save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") + saver = saver_lib.Saver() + + # Only one opaque var in a cudnn layer. + assert len(rnn.trainable_variables) == 1 + reset_params = state_ops.assign( + rnn.trainable_variables[0], + array_ops.zeros( + array_ops.shape(rnn.trainable_variables[0]), dtype=dtype)) + + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + inputs, initial_state = model.SynthesizeInput(seq_length, batch_size) + total_sum_v = model.Feed(sess, inputs, initial_state) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + sess.run(reset_params) + saver.restore(sess, save_path) + total_sum_v_restored = model.Feed(sess, inputs, initial_state) + self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) + + def _TestSaveRestoreHelper(self, rnn_mode): + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + dtype_list = [dtypes.float32, dtypes.float64] + for direction, dtype in itertools.product(directions, dtype_list): + self._TestSaveRestoreVariable(rnn_mode, direction, dtype) + self._TestSaveRestoreTwoVariables(rnn_mode, direction, dtype) + self._TestSaveRestoreOutput(rnn_mode, direction, dtype) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRepeatedlyCreateCustomSaveable(self): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default(): + random_seed.set_random_seed(1234) + model = CudnnTestModel( + CUDNN_LSTM, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32) + with self.assertRaisesRegexp(RuntimeError, + "Cudnn saveable already created"): + model.rnn._create_saveable() + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreLSTM(self): + self._TestSaveRestoreHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreGRU(self): + self._TestSaveRestoreHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNTanh(self): + self._TestSaveRestoreHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNRelu(self): + self._TestSaveRestoreHelper(CUDNN_RNN_RELU) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. +class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleLSTM(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleGRU(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNTanh(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNRelu(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_RELU) + + def _TestCudnnCompatibleRnnCellsHelper(self, rnn_mode): + configs = [ + { + "num_layers": 1, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 2, + "seq_length": 8, + "num_units": 4, + "input_size": 8, + "batch_size": 16, + }, + { + "num_layers": 2, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 1, + "seq_length": 2, + "num_units": 2, + "input_size": 4, + "batch_size": 1, + }, + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + for cfg, direction in zip(configs, directions): + self._TestCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], + cfg["num_units"], cfg["input_size"], + cfg["batch_size"], rnn_mode, direction) + + def _TestCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units, + input_size, batch_size, rnn_mode, direction): + dtype = dtypes.float32 + # Train graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=True) + target_output = array_ops.placeholder(dtype=dtype) + loss_op = losses.log_loss( + labels=target_output, predictions=model.total_sum) + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss_op) + + saver = saver_lib.Saver() + + # Train Cudnn model + seed = 0 + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + # Train 128 steps + num_steps = 128 + for _ in range(num_steps): + inputs, _ = model.SynthesizeInput(seq_length, batch_size, seed) + targets = np.random.rand() + sess.run( + train_op, + feed_dict={ + model.inputs: inputs, + model.initial_state: model.ZeroState(batch_size), + target_output: targets + }) + seed += 1 + + save_path = os.path.join(self.get_temp_dir(), + ("cudnn-rnn-%s-test" % rnn_mode)) + save_v = saver.save(sess, save_path) + self.assertEqual(save_path, save_v) + + # Cudnn inference graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + saver = saver_lib.Saver() + + inference_input = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + saver.restore(sess, save_path) + + # Cudnn inference + cudnn_outputs_v, cudnn_output_states_v = model.Feed( + sess, inference_input, return_sum=False) + + # Canonical RNN inference graph + with ops.Graph().as_default() as g: + cell_inputs = array_ops.placeholder( + dtype, shape=[seq_length, batch_size, input_size]) + if direction == CUDNN_RNN_UNIDIRECTION: + # outputs is one tensor, states are num_layer tuples, each 2 tensors + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN(rnn, cell_inputs) + if rnn_mode == CUDNN_LSTM: + output_h = array_ops.stack([s.h for s in states]) + output_c = array_ops.stack([s.c for s in states]) + else: + output_state = array_ops.stack([s for s in states]) + else: + # outputs is one tensor. + # states is a tuple of 2 tuples: + # each sub tuple is num_layer tuples, each with 2 tensors. + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN( + rnn, cell_inputs, is_bidi=True) + output_state_fw, output_state_bw = states + if rnn_mode == CUDNN_LSTM: + output_h, output_c = [], [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_h.append(array_ops.stack([s_fw.h, s_bw.h])) + output_c.append(array_ops.stack([s_fw.c, s_bw.c])) + output_h = array_ops.concat(output_h, axis=0) + output_c = array_ops.concat(output_c, axis=0) + else: + output_state = [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_state.append(array_ops.stack([s_fw, s_bw])) + output_state = array_ops.concat(output_state, axis=0) + saver = saver_lib.Saver() + + with self.test_session(use_gpu=True, graph=g) as sess: + saver.restore(sess, save_path) + + # BlockCell inference + if rnn_mode == CUDNN_LSTM: + outputs_v, output_h_v, output_c_v = sess.run( + [outputs, output_h, output_c], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v) + cudnn_output_h_v, cudnn_output_c_v = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_h_v) + self.assertAllClose(cudnn_output_c_v, output_c_v) + else: + outputs_v, output_state_v = sess.run( + [outputs, output_state], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-5, rtol=1e-5) + (cudnn_output_h_v,) = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_state_v, atol=1e-5, + rtol=1e-5) + + +class CudnnRNNTestParamsSize(TensorFlowTestCase): + + def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size, + direction): + logging.info("Testing one lstm param size with config: %s", locals()) + dtype = dtypes.float32 + + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + dtype=dtype, + direction=direction) + rnn = model.rnn + + # Min param size estimate = sum(weights.size) + sum(biases.size) + min_params_size = ( + np.sum(map(np.prod, rnn.canonical_weight_shapes)) + + np.sum([sp[0] for sp in rnn.canonical_bias_shapes])) + + opaque_params = rnn.trainable_variables[0] + with self.test_session(use_gpu=True, graph=ops.get_default_graph()): + variables.global_variables_initializer().run() + opaque_params_size_v = opaque_params.eval().size + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testOpaqueParamsSize(self): + test_configs = [ + [4, 200, 200], + [4, 200, 300], + [4, 200, 100], + [1, 100, 200], + [2, 200, 100], + [3, 200, 400], + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + rnns = [CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH] + for (rnn, config, direction) in itertools.product(rnns, test_configs, + directions): + num_layers, num_units, input_size = config + with ops.Graph().as_default(): + self._TestOpaqueParamsSize(rnn, num_layers, num_units, input_size, + direction) + + +class CudnnRNNTestTraining(TensorFlowTestCase): + + def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1): + """Compute the numeric gradient of y wrt to x. + + Args: + sess: The TF session constructed with a graph containing x and y. + y: A scalar TF Tensor in the graph constructed in sess. + x: A TF Tensor in the graph constructed in sess. + delta: Gradient checker's small perturbation of x[i]. + step: Only compute numerical gradients for a subset of x values. + I.e. dy/dx[i] is computed if i % step == 0. + Returns: + A Tensor of the same shape and dtype as x. If x[i] is not chosen + to compute the numerical gradient dy/x[i], the corresponding + value is set to 0. + """ + + x_data = sess.run(x) + x_size = x_data.size + x_shape = x_data.shape + + numeric_grad = np.zeros(x_size, dtype=x_data.dtype) + + for i in range(0, x_size, step): + x_pos = x_data.copy() + if x_size == 1: + x_pos += delta + else: + x_pos.flat[i] += delta + y_pos_feed_dict = dict([(x.name, x_pos)]) + y_pos = sess.run(y, feed_dict=y_pos_feed_dict) + + x_neg = x_data.copy() + if x_size == 1: + x_neg -= delta + else: + x_neg.flat[i] -= delta + y_neg_feed_dict = dict([(x.name, x_neg)]) + y_neg = sess.run(y, feed_dict=y_neg_feed_dict) + numeric_grad[i] = (y_pos - y_neg) / (2 * delta) + return numeric_grad.reshape(x_shape) + + def _GradientCheck(self, sess, y, xs, tolerance=1e-6, delta=1e-4): + sym_grads_t = gradients.gradients(y, xs) + sym_grads = sess.run(sym_grads_t) + + num_grads = [self._ComputeNumericGrad(sess, y, x, delta) for x in xs] + self.assertEqual(len(sym_grads), len(num_grads)) + for sym, num in zip(sym_grads, num_grads): + self.assertFalse(np.any(np.isnan(sym))) + self.assertFalse(np.any(np.isnan(num))) + self.assertAllClose(sym, num, atol=tolerance, rtol=tolerance) + + def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, + batch_size, seq_length, dir_count, dropout, dtype, + delta, tolerance): + # Gradient checking runs two forward ops with almost the same input. Need to + # make sure the drop patterns across the two runs are the same. + logging.info("Training test with config: %s", locals()) + old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) + random_seed.set_random_seed(5678) + has_input_c = (rnn_mode == CUDNN_LSTM) + direction = (CUDNN_RNN_UNIDIRECTION + if dir_count == 1 else CUDNN_RNN_BIDIRECTION) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dropout=dropout, + dtype=dtype, + training=True, + bias_initializer=init_ops.random_normal_initializer( + mean=1., dtype=dtype)) + rnn = model.rnn + params = rnn.trainable_variables[0] + + inputs = variables.Variable( + random_ops.random_uniform( + [seq_length, batch_size, input_size], dtype=dtype), + dtype=dtype) + input_h = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + if has_input_c: + input_c = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + total_sum = model.FProp(inputs, initial_state, training=True) + + with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + all_inputs = [inputs, params] + for s in initial_state: + all_inputs.append(s) + self._GradientCheck( + sess, total_sum, all_inputs, tolerance=tolerance, delta=delta) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state + + def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): + dropouts = [0., 0.5, 1.] + for config, dropout in itertools.product(test_configs, dropouts): + dtype = config.get("dtype", dtypes.float32) + delta = config.get("delta", 1e-4) + tolerance = config.get("tolerance", 1e-6) + dir_count = config.get("dir_count", 1) + shape = config["shape"] + with ops.Graph().as_default(): + self._TestOneSimpleTraining(rnn_mode, shape["num_layers"], + shape["num_units"], shape["input_size"], + shape["batch_size"], shape["seq_length"], + dir_count, dropout, dtype, delta, tolerance) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTM64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTM32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-4, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRU64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + } + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRU32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 4e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanh64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanh32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 5e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNRelu64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNRelu32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 7e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..f6c206022c68c4ba78d895f44288f4b180d199c0 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -0,0 +1,556 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cudnn RNN operators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as base_layer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + +_cudnn_rnn_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) + +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH + +# Half for cell input, half for hidden states. +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + +CUDNN_INPUT_LINEAR_MODE = cudnn_rnn_ops.CUDNN_INPUT_LINEAR_MODE +CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE +CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE + + +class _CudnnRNN(base_layer.Layer): + # pylint:disable=line-too-long + """Abstract class for RNN layers with Cudnn implementation. + + Cudnn RNNs have two major differences from other platform-independent RNNs tf + provides: + * Cudnn LSTM and GRU are mathematically different from their tf counterparts. + (e.g. @{tf.contrib.rnn.LSTMBlockCell} and @{tf.nn.rnn_cell.GRUCell}. + * Cudnn-trained checkpoints are not directly compatible with tf RNNs: + * They use a single opaque parameter buffer for the entire (possibly) + multi-layer multi-directional RNN; Whereas tf RNN weights are per-cell and + layer. + * The size and layout of the parameter buffers may change between + CUDA/CuDNN/GPU generations. Because of that, the opaque parameter variable + does not have a static shape and is not partitionable. Instead of using + partitioning to alleviate the PS's traffic load, try building a + multi-tower model and do gradient aggregation locally within the host + before updating the PS. See https://www.tensorflow.org/performance/performance_models#parameter_server_variables + for a detailed performance guide. + + Consequently, if one plans to use Cudnn trained models on both GPU and CPU + for inference and training, one needs to: + * Create a CudnnOpaqueParamsSaveable subclass object to save RNN params in + canonical format. (This is done for you automatically during layer building + process.) + * When not using a Cudnn RNN class, use CudnnCompatibleRNN classes to load the + checkpoints. These classes are platform-independent and perform the same + computation as Cudnn for training and inference. + Similarly, CudnnCompatibleRNN-trained checkpoints can be loaded by CudnnRNN + classes seamlessly. + + Below is a typical workflow(using LSTM as an example): + for detailed performance guide. + + # Use Cudnn-trained checkpoints with CudnnCompatibleRNNs + ```python + with tf.Graph().as_default(): + lstm = CudnnLSTM(num_layers, num_units, direction, ...) + + outputs, output_states = lstm(inputs, initial_states, training=True) + + # If user plans to delay calling the cell with inputs, one can do + # lstm.build(input_shape) + + saver = Saver() + + # training subgraph + ... + + # Once in a while save the model. + saver.save(save_path) + + # Inference subgraph for unidirectional RNN on, e.g., CPU or mobile. + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + + # NOTE: Even if there's only one layer, the cell needs to be wrapped in + # MultiRNNCell. + cell = tf.nn.rnn_cell.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + + # Leave the scope arg unset. + outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state, ...) + + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + + # Inference subgraph for bidirectional RNN + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + # Leave the scope arg unset. + (outputs, output_state_fw, + output_state_bw) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn( + cells_fw, cells_bw, inputs, ...) + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + ``` + """ + # pylint:enable=line-too-long + + # The following are constants defined by subclasses. + # Type of RNN cell. + _rnn_mode = None + # Number of cell weights(or biases) per layer. + _num_params_per_layer = None + # Custom SaveableObject class for the CudnnRNN class. + _saveable_cls = None + + # TODO(jamesqin): support float16 CuDNN RNN + def __init__(self, + num_layers, + num_units, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=None, + dtype=dtypes.float32, + kernel_initializer=None, + bias_initializer=None, + name=None): + """Creates a CudnnRNN model from model spec. + + Args: + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It can be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Can be either + 'unidirectional' or 'bidirectional' + dropout: dropout rate, a number between [0, 1]. Dropout is applied on + inputs of each layer. When set to 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + dtype: tf.float32 or tf.float64 + kernel_initializer: starting value to initialize the weight. + bias_initializer: starting value to initialize the bias + (default is all zeros). + name: VariableScope for the created subgraph; defaults to class name. + This only serves the default scope if later no scope is specified when + invoking __call__(). + + Raises: + ValueError: if direction is invalid. Or dtype is not supported. + """ + super(_CudnnRNN, self).__init__(dtype=dtype, name=name) + cudnn_rnn_ops.check_direction(direction) + cudnn_rnn_ops.check_input_mode(input_mode) + + if dtype not in [dtypes.float32, dtypes.float64]: + raise ValueError("Only support float32, float64, provided %s" % dtype) + # Layer self.dtype is type name, the original DType object is kept here. + self._plain_dtype = dtype + self._num_layers = num_layers + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + self._dropout = dropout + self._seed = seed + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + # Init input_size to None, which will be set after build(). + self._input_size = None + self._saveable = None + + @property + def num_layers(self): + return self._num_layers + + @property + def num_units(self): + return self._num_units + + @property + def input_mode(self): + """Input mode of first layer. + + Indicates whether there is a linear projection between the input and the + actual computation before the first layer. It can be + * 'linear_input': (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior) + * 'skip_input': 'skip_input' is only allowed when input_size == num_units. + * 'auto_select'. implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + + Returns: + 'linear_input', 'skip_input' or 'auto_select'. + """ + return self._input_mode + + @property + def input_size(self): + if not self._input_size: + raise ValueError( + "\'input_size\' is unknown since layer has not been built.") + return self._input_size + + @property + def rnn_mode(self): + """Type of RNN cell used. + + Returns: + `lstm`, `gru`, `rnn_relu` or `rnn_tanh`. + """ + return self._rnn_mode + + @property + def direction(self): + """Returns `unidirectional` or `bidirectional`.""" + return self._direction + + @property + def num_dirs(self): + return 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + + @property + def saveable(self): + return self._saveable + + @property + def canonical_weight_shapes(self): + """Shapes of Cudnn canonical weight tensors.""" + if not self._input_size: + raise RuntimeError( + "%s.canonical_weight_shapes invoked before input shape is known" % + type(self).__name__) + + shapes = [] + for i in range(self._num_layers): + shapes.extend(self._canonical_weight_shape(i)) + return shapes + + @property + def canonical_bias_shapes(self): + """Shapes of Cudnn canonical bias tensors.""" + return self._canonical_bias_shape(0) * self._num_layers + + def _update_trainable_weights(self, getter, *args, **kwargs): + """Custom getter for layer variables.""" + # Add variables to layer's `(non_)trainable_weights` list(s). + variable = getter(*args, **kwargs) + trainable = kwargs.get("trainable", True) + if trainable and variable not in self._trainable_weights: + self._trainable_weights.append(variable) + elif not trainable and variable not in self._non_trainable_weights: + self._non_trainable_weights.append(variable) + return variable + + def build(self, input_shape): + """Create variables of the Cudnn RNN. + + It can be called manually before `__call__()` or automatically through + `__call__()`. In the former case, subsequent `__call__()`s will skip + creating variables. + Args: + input_shape: network input tensor shape, a python list or a TensorShape + object with 3 dimensions. + Raises: + ValueError: if input_shape has wrong dimension or unknown 3rd dimension. + """ + if self.built: + return + + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape.ndims != 3: + raise ValueError("Expecting input_shape with 3 dims, got %d" % + input_shape.ndims) + if input_shape[-1].value is None: + raise ValueError("The last dimension of the inputs to `CudnnRNN` " + "should be defined. Found `None`.") + self._input_size = input_shape[-1].value + self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + + self._set_scope(None) + + # Not using base class `add_variable()` since the it calls + # `tf.get_variable()` with a callable initializer whereas here with a + # tensor. The difference is mandated to support forward-compatibility with + # Cudnn. + with vs.variable_scope( + self._scope, + reuse=self.built, + custom_getter=self._update_trainable_weights): + if self._kernel_initializer is None: + self._kernel_initializer = init_ops.glorot_uniform_initializer( + seed=self._seed, dtype=self._plain_dtype) + if self._bias_initializer is None: + self._bias_initializer = init_ops.constant_initializer( + 0.0, dtype=self._plain_dtype) + + weights = [ + self._kernel_initializer(sp, dtype=self._plain_dtype) + for sp in self.canonical_weight_shapes + ] + biases = [ + self._bias_initializer(sp, dtype=self._plain_dtype) + for sp in self.canonical_bias_shapes + ] + opaque_params_t = self._canonical_to_opaque(weights, biases) + + if vs.get_variable_scope().partitioner is not None: + logging.warn( + "Partitioner is not supported for Cudnn RNN layer variables, using " + "it will create forward-compatibility issues with future " + "CUDA/CuDNN generations.") + # Initialize opaque params with a tensor. + self.kernel = vs.get_variable( + "opaque_kernel", initializer=opaque_params_t, validate_shape=False) + # Create saveable in the outer scope of the cudnn subgraph, such that + # alternative subgraph with platform-independent rnn cells can load the + # checkpoints directly. + if not (self.built or vs.get_variable_scope().reuse): + self._create_saveable() + self.built = True + + def call(self, inputs, initial_state=None, training=True): + """Runs the forward step for the RNN model. + + Args: + inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. + initial_state: a tuple of tensor(s) of shape + `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use + zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. + training: whether this operation will be used in training or inference. + Returns: + output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. + It is a `concat([fwd_output, bak_output], axis=2)`. + output_states: a tuple of tensor(s) of the same shape and structure as + `initial_state`. + Raises: + ValueError: initial_state is not a tuple. + """ + if initial_state is not None and not isinstance(initial_state, tuple): + raise ValueError("Invalid initial_state type: %s, expecting tuple.", + type(initial_state)) + dtype = self.dtype + inputs = ops.convert_to_tensor(inputs, dtype=dtype) + + batch_size = array_ops.shape(inputs)[1] + if initial_state is None: + initial_state = self._zero_state(batch_size) + if self._rnn_mode == CUDNN_LSTM: + h, c = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + else: + h, = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + h = ops.convert_to_tensor(h, dtype=dtype) + if self._rnn_mode == CUDNN_LSTM: + c = ops.convert_to_tensor(c, dtype=dtype) + else: + # For model that doesn't take input_c, replace with a dummy tensor. + c = array_ops.constant([], dtype=dtype) + outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, + training) + if self._rnn_mode == CUDNN_LSTM: + return outputs, (output_h, output_c) + else: + return outputs, (output_h,) + + def state_shape(self, batch_size): + raise NotImplementedError + + def _zero_state(self, batch_size): + res = [] + for sp in self.state_shape(batch_size): + res.append(array_ops.zeros(sp, dtype=self.dtype)) + return tuple(res) + + def _canonical_weight_shape(self, layer): + """Shapes of Cudnn canonical weight tensors for given layer.""" + if layer < 0 or layer >= self._num_layers: + raise ValueError("\'layer\' is not valid, got %s, expecting [%d, %d]" % + (layer, 0, self._num_layers-1)) + if not self._input_size: + raise RuntimeError( + "%s._canonical_weight_shape invoked before input shape is known" % + type(self).__name__) + + input_size = self._input_size + num_units = self._num_units + num_gates = self._num_params_per_layer // 2 + is_bidi = self._direction == CUDNN_RNN_BIDIRECTION + + if layer == 0: + wts_applied_on_inputs = [(num_units, input_size)] * num_gates + else: + if is_bidi: + wts_applied_on_inputs = [(num_units, 2 * num_units)] * num_gates + else: + wts_applied_on_inputs = [(num_units, num_units)] * num_gates + wts_applied_on_hidden_states = [(num_units, num_units)] * num_gates + tf_wts = wts_applied_on_inputs + wts_applied_on_hidden_states + return tf_wts if not is_bidi else tf_wts * 2 + + def _canonical_bias_shape(self, unused_layer): + """Shapes of Cudnn canonical bias tensors for given layer.""" + num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + return [[self._num_units]] * num_dirs * self._num_params_per_layer + + def _canonical_to_opaque(self, cu_weights, cu_biases): + if not self._input_size: + raise RuntimeError( + "%s._canonical_to_opaque invoked before input shape is known" % + type(self).__name__) + return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( + rnn_mode=self._rnn_mode, + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + input_mode=self._input_mode, + direction=self._direction) + + def _forward(self, inputs, h, c, opaque_params, training): + output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access + inputs, + h, + c, + opaque_params, + training, + self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + dropout=self._dropout, + seed=self._seed) + return output, (output_h, output_c) + + def _create_saveable(self): + """Create custom saveable for the Cudnn layer. + + Called during layer building process to make sharing checkpoints between + Cudnn and Cudnn-compatible RNNs easy. + Returns: + a `CudnnOpaqueParamsSaveable` object. + Raises: + RuntimeError: if any custom saveable is already created for this layer. + """ + if self._saveable is not None: + raise RuntimeError("Cudnn saveable already created.") + self._saveable = self._saveable_cls( # pylint:disable=not-callable + self.trainable_variables[0], + self.num_layers, + self.num_units, + self.input_size, + self.input_mode, + self.direction, + scope=vs.get_variable_scope(), + name="%s_saveable" % self.trainable_variables[0].op.name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + +class CudnnLSTM(_CudnnRNN): + """Cudnn implementation of LSTM layer.""" + _rnn_mode = CUDNN_LSTM + _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnLSTMSaveable + + def state_shape(self, batch_size): + """Shape of Cudnn LSTM states. + + Shape is a 2-element tuple. Each is + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return ([self.num_layers * self.num_dirs, batch_size, self.num_units], + [self.num_layers * self.num_dirs, batch_size, self.num_units]) + + +class _CudnnRNNNoInputC(_CudnnRNN): + """Abstract simple CudnnRNN layer without input_c.""" + + def state_shape(self, batch_size): + """Shape of the state of Cudnn RNN cells w/o. input_c. + + Shape is a 1-element tuple, + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return [self.num_layers * self.num_dirs, batch_size, self.num_units], + + +class CudnnGRU(_CudnnRNNNoInputC): + """Cudnn implementation of the GRU layer.""" + _rnn_mode = CUDNN_GRU + _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnGRUSaveable + + +class CudnnRNNTanh(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-tanh layer.""" + _rnn_mode = CUDNN_RNN_TANH + _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNTanhSaveable + + +class CudnnRNNRelu(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-relu layer.""" + _rnn_mode = CUDNN_RNN_RELU + _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNReluSaveable diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index f6eeb016755b66a8ac2a4b4e711543ebdf468269..7d658c746ee1ecd21cefca9c9e52f611869f6176 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -65,7 +65,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( - num_units, forget_bias=0, clip_cell=False, use_peephole=False, + num_units, forget_bias=0, cell_clip=None, use_peephole=False, reuse=reuse) self._names.update({"scope": "cudnn_compatible_lstm_cell"}) @@ -717,12 +717,6 @@ _cudnn_rnn_common_doc_string = """ """ -def _check_direction(direction): - if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s, expect %s or %s" % - (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) - - def _check_rnn_mode(rnn_mode): if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU): raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % @@ -737,14 +731,31 @@ def _get_seed(seed): return seed, seed2 +def check_direction(direction): + """Check validity of direction.""" + if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): + raise ValueError("Invalid direction: %s, expecting %s or %s" % + (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) + + +def check_input_mode(input_mode): + if input_mode not in (CUDNN_INPUT_LINEAR_MODE, CUDNN_INPUT_SKIP_MODE, + CUDNN_INPUT_AUTO_MODE): + raise ValueError("Invalid input_mode: %s, expect one of (%s, %s, %s)" % + (input_mode, CUDNN_INPUT_LINEAR_MODE, + CUDNN_INPUT_SKIP_MODE, CUDNN_INPUT_AUTO_MODE)) + + def _get_num_params(rnn_mode, num_layers, direction): """Return num params for given Cudnn config.""" if rnn_mode == CUDNN_LSTM: - num_params_per_layer = 8 + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER elif rnn_mode == CUDNN_GRU: - num_params_per_layer = 6 - elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH): - num_params_per_layer = 2 + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_RELU: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) num_params = num_layers * num_params_per_layer @@ -794,7 +805,8 @@ def _cudnn_rnn(inputs, outputs, output_h, output_c """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( input=inputs, @@ -1017,16 +1029,16 @@ def cudnn_rnn_tanh(inputs, seed, name) -def cudnn_rnn_params_to_canonical(rnn_mode, - num_layers, - num_units, - input_size, - params, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_to_canonical(rnn_mode, + num_layers, + num_units, + input_size, + params, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Convert cudnn opaque params to canonical. Args: @@ -1058,7 +1070,8 @@ def cudnn_rnn_params_to_canonical(rnn_mode, """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) num_params = _get_num_params(rnn_mode, num_layers, direction) seed, seed2 = random_seed.get_seed(seed) weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( @@ -1077,17 +1090,17 @@ def cudnn_rnn_params_to_canonical(rnn_mode, return weights, biases -def cudnn_rnn_canonical_to_params(rnn_mode, - num_layers, - num_units, - input_size, - weights, - biases, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_canonical_to_opaque_params(rnn_mode, + num_layers, + num_units, + input_size, + weights, + biases, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Converts params from the canonical format to a specific format of cuDNN. Args: @@ -1119,7 +1132,8 @@ def cudnn_rnn_canonical_to_params(rnn_mode, ValueError: if rnn_mode or direction is invalid. """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( rnn_mode=rnn_mode, @@ -1136,16 +1150,16 @@ def cudnn_rnn_canonical_to_params(rnn_mode, name=name) -def cudnn_opaque_params_size(rnn_mode, - num_layers, - num_units, - input_size, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_size(rnn_mode, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32, + dropout=0, + seed=0, + name=None): """Returns opaque params size for specific Cudnn config. Args: @@ -1176,7 +1190,8 @@ def cudnn_opaque_params_size(rnn_mode, ValueError: if rnn_mode or direction is invalid. """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_params_size( rnn_mode=rnn_mode, @@ -1278,7 +1293,7 @@ class _CudnnRNN(object): Returns: The calculated parameter buffer size. """ - return cudnn_opaque_params_size( + return cudnn_rnn_opaque_params_size( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1327,7 +1342,7 @@ class _CudnnRNN(object): Returns: A function for the specific-to-canonical conversion. """ - return cudnn_rnn_params_to_canonical( + return cudnn_rnn_opaque_params_to_canonical( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1348,7 +1363,7 @@ class _CudnnRNN(object): Returns: A function for the canonical-to-params-to-specific conversion.. """ - return cudnn_rnn_canonical_to_params( + return cudnn_rnn_canonical_to_opaque_params( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 2557eb4fc26d0852d759349eddf5ade07f4d8e04..ee96269a739ebb138ea88cf4e192f7925e85447d 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -13,7 +13,7 @@ py_library( "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", ], ) diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md index 7c59a1ffc37085f17f8f4e693c0bc874c77f914a..30e909111f460bb4d0ea5fcdefaf5bdedc93b9c0 100644 --- a/tensorflow/contrib/data/README.md +++ b/tensorflow/contrib/data/README.md @@ -1,8 +1,39 @@ `tf.contrib.data` API ===================== -This directory contains the Python API for the `tf.contrib.data.Dataset` and -`tf.contrib.data.Iterator` classes, which can be used to build input pipelines. +NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead. +We are continuing to support existing code using the `tf.contrib.data` APIs in +the current version of TensorFlow, but will eventually remove support. The +`tf.data` APIs are subject to backwards compatibility guarantees. -The documentation for this API has moved to the programmers' -guide, [here](../../docs_src/programmers_guide/datasets.md). +Porting your code to `tf.data` +------------------------------ + +The `tf.contrib.data.Dataset` class has been renamed to `tf.data.Dataset`, and +the `tf.contrib.data.Iterator` class has been renamed to `tf.data.Iterator`. +Most code can be ported by removing `.contrib` from the names of the classes. +However, there are some small differences, which are outlined below. + +The arguments accepted by the `Dataset.map()` transformation have changed: + +* `dataset.map(..., num_threads=T)` is now `dataset.map(num_parallel_calls=T)`. +* `dataset.map(..., output_buffer_size=B)` is now + `dataset.map(...).prefetch(B). + +Some transformations have been removed from `tf.data.Dataset`, and you must +instead apply them using `Dataset.apply()` transformation. The full list of +changes is as follows: + +* `dataset.dense_to_sparse_batch(...)` is now + `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`. +* `dataset.enumerate(...)` is now + `dataset.apply(tf.contrib.data.enumerate_dataset(...))`. +* `dataset.group_by_window(...)` is now + `dataset.apply(tf.contrib.data.group_by_window(...))`. +* `dataset.ignore_errors()` is now + `dataset.apply(tf.contrib.data.ignore_errors())`. +* `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`. + +The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have +been removed from the API. Please open a GitHub issue if you have a need for +either of these. diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index b930bfa0b726cc0a44b4b601b604e99444e481fd..7ff26e087bbb61963948be1a2edaaa407d0ba1f8 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -32,6 +32,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@rejection_resample @@sloppy_interleave +@@get_single_element """ from __future__ import absolute_import @@ -44,6 +45,7 @@ from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import unbatch from tensorflow.contrib.data.python.ops.dataset_ops import Dataset +from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.grouping import group_by_window @@ -54,7 +56,7 @@ from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave -from tensorflow.python.data.ops.dataset_ops import Iterator +from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 61a067ec4256666e25d5d8fe8436214471056248..c34c9dad9b5afb1f1232c8bff4c26770199ce7b6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -62,6 +62,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) @@ -160,6 +161,7 @@ py_test( "//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:session", + "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -188,6 +190,7 @@ py_test( "//tensorflow/python:script_ops", "//tensorflow/python:session", "//tensorflow/python:training", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) @@ -252,6 +255,7 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -261,7 +265,6 @@ py_test( srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -275,6 +278,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -338,6 +342,7 @@ py_test( "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 91f100e0f0ccd452ab9a9e673d6714b718ccbeb2..add17ff8bcea0f228dc36ec6157fe95b9ce44d80 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -256,8 +256,9 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetWithUnknownShape(self): components = np.random.randint(5, size=(40,)).astype(np.int32) iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x, x], x)).dense_to_sparse_batch( - 4, [5, -1]).make_initializable_iterator()) + .map(lambda x: array_ops.fill([x, x], x)).apply( + batching.dense_to_sparse_batch( + 4, [5, -1])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -285,7 +286,8 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) iterator = (dataset_ops.Dataset.from_tensors(input_tensor) - .dense_to_sparse_batch(4, [-2]).make_initializable_iterator()) + .apply(batching.dense_to_sparse_batch(4, [-2])) + .make_initializable_iterator()) init_op = iterator.initializer with self.test_session() as sess: @@ -424,6 +426,102 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) + def testBatchAndMapDataset(self): + """Test a dataset that maps a TF function across its input elements.""" + # The pipeline is TensorSliceDataset -> + # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count) + .apply(batching.map_and_batch(_map_fn, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.test_session() as sess: + # Batch of a finite input, where the batch_size divides the + # total number of elements. + sess.run(init_op, feed_dict={count: 28, batch_size: 14}) + num_batches = (28 * 7) // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of a finite input, where the batch_size does not + # divide the total number of elements. + sess.run(init_op, feed_dict={count: 14, batch_size: 8}) + + # We expect (num_batches - 1) full-sized batches. + num_batches = int(math.ceil((14 * 7) / 8)) + for i in range(num_batches - 1): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(8): + self.assertAllEqual(component[(i*8 + j) % 7]**2, + result_component[j]) + # The last batch should fail with `OutOfRange`. + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of an empty input should fail straight away. + sess.run(init_op, feed_dict={count: 0, batch_size: 8}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Empty batch should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + + def testBatchAndMapDatasetFails(self): + """Test a dataset that maps a TF function across its input elements.""" + dataset = dataset_ops.Dataset.from_tensors( + array_ops.check_numerics( + constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): + sess.run(init_op, feed_dict={batch_size: 14}) + + def testBatchAndMapDatasetShapeMismatch(self): + """Test a dataset that maps a TF function across its input elements.""" + def generator(): + yield [1] + yield [2] + yield [3] + yield [[4, 5, 6]] + + dataset = dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int32) + batch_size = 4 + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "number of elements does not match"): + sess.run(get_next) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py index 364c1be8eafccb77d4a54241ed758fc6cadbd00b..9818020680afb9d0f0197d272ec5339c6358db36 100644 --- a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py @@ -24,6 +24,7 @@ import tempfile import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -59,8 +60,8 @@ class FilesystemCacheDatasetTest(test.TestCase): # Create initialization ops for iterators without and with # caching, respectively. - iterator = dataset_ops.Iterator.from_structure(cache_dataset.output_types, - cache_dataset.output_shapes) + iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types, + cache_dataset.output_shapes) init_fifo_op = iterator.make_initializer(repeat_dataset) init_cache_op = iterator.make_initializer(cache_dataset) diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index f74362d4e8237ba8aa2522d4438f89a5b5dea448..a66714feda98d24778d9049b19455f28e4f76197 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -434,6 +434,30 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testFromGeneratorImplicitConversion(self): + def generator(): + yield [1] + yield [2] + yield [3] + + for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]: + iterator = (dataset_ops.Dataset.from_generator( + generator, output_types=dtype, output_shapes=[1]) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual(dtype, get_next.dtype) + + with self.test_session() as sess: + sess.run(init_op) + for expected in [[1], [2], [3]]: + next_val = sess.run(get_next) + self.assertEqual(dtype.as_numpy_dtype, next_val.dtype) + self.assertAllEqual(expected, next_val) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testFromGeneratorTypeError(self): def generator(): yield np.array([1, 2, 3], dtype=np.int64) @@ -451,7 +475,7 @@ class DatasetConstructorTest(test.TestCase): sess.run(init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) self.assertAllEqual([4, 5, 6], sess.run(get_next)) - with self.assertRaisesOpError(r"element of type .*int64.* was expected"): + with self.assertRaisesOpError(r"invalid literal for long\(\)"): sess.run(get_next) self.assertAllEqual([7, 8, 9], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py index faad6e925d78e273d4c308d42598aa12edc792e2..02379d064d4ab857ce9c7d13881a3ae37eea0980 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -44,7 +45,7 @@ class IteratorClusterTest(test.TestCase): iterator_3_handle = iterator_3.string_handle() with ops.device("/job:worker/replica:0/task:0/cpu:0"): - remote_it = dataset_ops.Iterator.from_string_handle( + remote_it = iterator_ops.Iterator.from_string_handle( iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes) get_next_op = remote_it.get_next() @@ -52,24 +53,19 @@ class IteratorClusterTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next_op) - def testRemoteIteratorUsingRemoteCallOp(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - with ops.device("/job:worker/replica:0/task:0/cpu:1"): + def _testRemoteIteratorHelper(self, device0, device1, target): + with ops.device(device1): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_3.make_one_shot_iterator() iterator_3_handle = iterator_3.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): - remote_iterator = dataset_ops.Iterator.from_string_handle( + remote_iterator = iterator_ops.Iterator.from_string_handle( h, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() - with ops.device("/job:worker/replica:0/task:0/cpu:0"): + with ops.device(device0): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) remote_op = functional_ops.remote_call( args=[iterator_3_handle], @@ -77,32 +73,35 @@ class IteratorClusterTest(test.TestCase): f=_remote_fn, target=target_placeholder) - with session.Session(worker[0].target) as sess: - elem = sess.run( - remote_op, - feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + with session.Session(target) as sess: + elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) self.assertEqual(elem, [1]) # Fails when target is cpu:0 where the resource is not located. with self.assertRaises(errors.InvalidArgumentError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:worker/replica:0/task:0/cpu:0" - }) - elem = sess.run( - remote_op, - feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + sess.run(remote_op, feed_dict={target_placeholder: device0}) + elem = sess.run(iterator_3.get_next()) self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:worker/replica:0/task:0/cpu:1" - }) + sess.run(remote_op, feed_dict={target_placeholder: device1}) + + def testRemoteIteratorUsingRemoteCallOp(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + worker, _ = test_util.create_local_cluster( + 1, 1, worker_config=worker_config) + + self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", + "/job:worker/replica:0/task:0/cpu:1", + worker[0].target) + + def testRemoteIteratorUsingRemoteCallOpCrossProcess(self): + workers, _ = test_util.create_local_cluster(2, 1) + + self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", + "/job:worker/replica:0/task:1/cpu:0", + workers[0].target) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 2b947766b94a75b8f999b0356b70c16eab8f175b..20f6d6ba34f49fa99d42961a6aa68ffed6b4f657 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -239,7 +240,7 @@ class IteratorTest(test.TestCase): # functions in this graph, to ensure that we are not # accidentally redefining functions with the same names in the # new graph. - iterator = dataset_ops.Iterator.from_structure( + iterator = iterator_ops.Iterator.from_structure( shared_name="shared_iterator", output_types=(dtypes.int64, dtypes.int64, dtypes.float64), output_shapes=([], [3], [])) @@ -269,8 +270,8 @@ class IteratorTest(test.TestCase): constant_op.constant([1, 2, 3])) dataset_4 = dataset_ops.Dataset.from_tensors( constant_op.constant([4, 5, 6, 7])) - iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types, - [None]) + iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types, + [None]) dataset_3_init_op = iterator.make_initializer(dataset_3) dataset_4_init_op = iterator.make_initializer(dataset_4) @@ -306,12 +307,12 @@ class IteratorTest(test.TestCase): def testReinitializableIteratorStaticErrors(self): # Non-matching structure for types and shapes. with self.assertRaises(TypeError): - iterator = dataset_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64), [None]) + iterator = iterator_ops.Iterator.from_structure((dtypes.int64, + dtypes.float64), [None]) # Test validation of dataset argument. - iterator = dataset_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64)) + iterator = iterator_ops.Iterator.from_structure((dtypes.int64, + dtypes.float64)) # Incompatible structure. with self.assertRaises(ValueError): @@ -328,7 +329,7 @@ class IteratorTest(test.TestCase): [4., 5., 6., 7.], dtype=dtypes.float32)))) # Incompatible shapes. - iterator = dataset_ops.Iterator.from_structure( + iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), ([None], [])) with self.assertRaises(TypeError): iterator.make_initializer( @@ -344,7 +345,7 @@ class IteratorTest(test.TestCase): iterator_4 = dataset_4.make_one_shot_iterator() handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - feedable_iterator = dataset_ops.Iterator.from_string_handle( + feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) next_element = feedable_iterator.get_next() @@ -391,11 +392,11 @@ class IteratorTest(test.TestCase): handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - feedable_int_scalar = dataset_ops.Iterator.from_string_handle( + feedable_int_scalar = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, []) - feedable_int_vector = dataset_ops.Iterator.from_string_handle( + feedable_int_vector = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, [None]) - feedable_int_any = dataset_ops.Iterator.from_string_handle( + feedable_int_any = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32) with self.test_session() as sess: @@ -435,7 +436,7 @@ class IteratorTest(test.TestCase): @function.Defun(dtypes.string) def _remote_fn(h): - remote_iterator = dataset_ops.Iterator.from_string_handle( + remote_iterator = iterator_ops.Iterator.from_string_handle( h, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() @@ -495,7 +496,7 @@ class IteratorTest(test.TestCase): @function.Defun(dtypes.uint8) def _remote_fn(h): handle = script_ops.py_func(_encode_raw, [h], dtypes.string) - remote_iterator = dataset_ops.Iterator.from_string_handle( + remote_iterator = iterator_ops.Iterator.from_string_handle( handle, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() @@ -583,6 +584,31 @@ class IteratorTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(restore_op) + def testToSingleElement(self): + skip_value = array_ops.placeholder(dtypes.int64, shape=[]) + take_value = array_ops.placeholder_with_default( + constant_op.constant(1, dtype=dtypes.int64), shape=[]) + + dataset = (dataset_ops.Dataset.range(100) + .skip(skip_value) + .map(lambda x: x * x) + .take(take_value)) + + element = dataset_ops.get_single_element(dataset) + + with self.test_session() as sess: + self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) + self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) + self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset was empty."): + sess.run(element, feed_dict={skip_value: 100}) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset had more than one element."): + sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index fce418c2ab9a792cbce991d3d5600b80bc41a634..8a1d99499be702d91f87f65f443261b47ce5c5cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -182,7 +182,9 @@ class MapDatasetTest(test.TestCase): (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]: do_test(num_threads_val, output_buffer_size_val) - def _testDisposeParallelMapDataset(self, explicit_dispose): + def testImplicitDisposeParallelMapDataset(self): + # Tests whether a parallel map dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(1000). components = (np.arange(1000), @@ -195,21 +197,11 @@ class MapDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - if explicit_dispose: - dispose_op = iterator.dispose_op() with self.test_session() as sess: sess.run(init_op) for _ in range(3): sess.run(get_next) - if explicit_dispose: - sess.run(dispose_op) - - def testExplicitDisposeParallelMapDataset(self): - self._testDisposeParallelMapDataset(True) - - def testImplicitDisposeParallelMapDataset(self): - self._testDisposeParallelMapDataset(False) def testParallelMapUnspecifiedOutputSize(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index ecb6ab8171d5d63c00c304c483ce06a700db504e..c8a0072809c2eac30e255d29ecaee5a324449045 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -21,6 +21,7 @@ import os from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -279,8 +280,8 @@ class RangeDatasetTest(test.TestCase): # Create an empty IteratorResource and restore the Iterator into it. output_types = dtypes.int64 output_shapes = tensor_shape.scalar() - iterator = dataset_ops.Iterator.from_structure(output_types, - output_shapes) + iterator = iterator_ops.Iterator.from_structure(output_types, + output_shapes) restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, path) get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1f27a2d70483adf66c4ae1fc1af6be5f7014ad61..c9f88f3dfc9a062ccd0bcabe7eadf18c98191c1d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,10 +21,10 @@ import gzip import os import zlib -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -87,7 +87,7 @@ class TextLineDatasetTest(test.TestCase): filenames, compression_type=compression_type).repeat(num_epochs) batch_dataset = repeat_dataset.batch(batch_size) - iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) init_op = iterator.make_initializer(repeat_dataset) init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() @@ -199,7 +199,7 @@ class FixedLengthRecordReaderTest(test.TestCase): .repeat(num_epochs)) batch_dataset = repeat_dataset.batch(batch_size) - iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) init_op = iterator.make_initializer(repeat_dataset) init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() @@ -293,7 +293,7 @@ class FixedLengthRecordReaderTest(test.TestCase): def _restore_iterator(self): output_types = dtypes.string output_shapes = tensor_shape.scalar() - iterator = dataset_ops.Iterator.from_structure(output_types, output_shapes) + iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) get_next = iterator.get_next() restore_op = gen_dataset_ops.restore_iterator( iterator._iterator_resource, self._iterator_checkpoint_path()) @@ -575,7 +575,7 @@ class TFRecordDatasetTest(test.TestCase): self.num_epochs) batch_dataset = repeat_dataset.batch(self.batch_size) - iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) self.init_op = iterator.make_initializer(repeat_dataset) self.init_batch_op = iterator.make_initializer(batch_dataset) self.get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index a19c91707541f83003ef9af1cc2a8527a8e5f6c3..0ac8d7359f7234d98167277724780bf31555e6fb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -22,11 +22,8 @@ import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.ops import string_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import device_setter from tensorflow.python.util import compat @@ -51,10 +48,8 @@ class ResampleTest(test.TestCase): seed=27)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - variable_init_op = variables.local_variables_initializer() with self.test_session() as sess: - sess.run(variable_init_op) sess.run(init_op) returned = [] with self.assertRaises(errors.OutOfRangeError): @@ -75,23 +70,6 @@ class ResampleTest(test.TestCase): returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) - def testVariableDevicePlacement(self): - classes = np.random.randint(5, size=(20000,)) # Uniformly sampled - target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] - with ops.device( - device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")): - _ = (dataset_ops.Dataset.from_tensor_slices(classes).shuffle( - 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).apply( - resampling.rejection_resample( - target_dist=target_dist, - initial_dist=None, - class_func=lambda c, _: c, - seed=27))) - - self.assertEqual(1, len(variables.local_variables())) - self.assertEqual(b"", - compat.as_bytes(variables.local_variables()[0].device)) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5338ec56bf275e481a984964e39aa0c1ade3a752 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ScanDatasetTest(test.TestCase): + + def _count(self, start, step): + return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( + scan_ops.scan(start, lambda state, _: (state + step, state))) + + def testCount(self): + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._count(start, step).take(take).make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testFibonacci(self): + iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) + ).make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(1, sess.run(next_element)) + self.assertEqual(1, sess.run(next_element)) + self.assertEqual(2, sess.run(next_element)) + self.assertEqual(3, sess.run(next_element)) + self.assertEqual(5, sess.run(next_element)) + self.assertEqual(8, sess.run(next_element)) + + def testChangingStateShape(self): + # Test the fixed-point shape invariant calculations: start with + # initial values with known shapes, and use a scan function that + # changes the size of the state on each element. + def _scan_fn(state, input_value): + # Statically known rank, but dynamic length. + ret_longer_vector = array_ops.concat([state[0], state[0]], 0) + # Statically unknown rank. + ret_larger_rank = array_ops.expand_dims(state[1], 0) + return (ret_longer_vector, ret_larger_rank), (state, input_value) + + dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply( + scan_ops.scan(([0], 1), _scan_fn)) + self.assertEqual([None], dataset.output_shapes[0][0].as_list()) + self.assertIs(None, dataset.output_shapes[0][1].ndims) + self.assertEqual([], dataset.output_shapes[1].as_list()) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(5): + (longer_vector_val, larger_rank_val), _ = sess.run(next_element) + self.assertAllEqual([0] * (2**i), longer_vector_val) + self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testIncorrectStateType(self): + + def _scan_fn(state, _): + return constant_op.constant(1, dtype=dtypes.int64), state + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + def testIncorrectReturnType(self): + + def _scan_fn(unused_state, unused_input_value): + return constant_op.constant(1, dtype=dtypes.int64) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The scan function must return a pair comprising the new state and the " + "output value."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index d9bfca30bbfd66afead842c2bc3020e9d4bcc2d9..e9ebaf4f21534fb43218d9579127b4aeb1dbd85e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -22,6 +22,7 @@ import collections import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -52,7 +53,7 @@ class ShuffleDatasetTest(test.TestCase): # Create initialization ops for iterators without and with # shuffling, respectively. - iterator = dataset_ops.Iterator.from_structure( + iterator = iterator_ops.Iterator.from_structure( shuffle_dataset.output_types, shuffle_dataset.output_shapes) init_fifo_op = iterator.make_initializer(repeat_dataset) init_shuffle_op = iterator.make_initializer(shuffle_dataset) diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 29cd960d9cfb4acec0487691755e51ede48d8726..2a9b41d6df0b447d64dc6cf28961e08cab5f367f 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -16,7 +16,6 @@ py_library( "//tensorflow/python:script_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator", "//tensorflow/python/data/util:nest", ], ) @@ -50,6 +49,7 @@ py_library( "error_ops.py", "grouping.py", "resampling.py", + "scan_ops.py", "sloppy_ops.py", ], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 847f974940919bd0c0c9a3e7793cf82be99e6860..abc9212a87550745490b974d25a929a66287f785 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -50,7 +50,7 @@ def dense_to_sparse_batch(batch_size, row_shape): ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices ['a', 'b', 'c', 'a', 'b'], # values [2, 6]), # dense_shape - ([[2, 0], [2, 1], [2, 2], [2, 3]], + ([[0, 0], [0, 1], [0, 2], [0, 3]], ['a', 'b', 'c', 'd'], [1, 6]) } @@ -68,7 +68,7 @@ def dense_to_sparse_batch(batch_size, row_shape): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -87,7 +87,7 @@ def unbatch(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -106,7 +106,7 @@ def unbatch(): def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). - Like @{tf.contrib.data.Dataset.batch}, this transformation combines + Like @{tf.data.Dataset.batch}, this transformation combines consecutive elements of this dataset into batches. However, if the batch size does not evenly divide the input dataset size, this transformation will drop the final smaller element. @@ -115,7 +115,7 @@ def batch_and_drop_remainder(batch_size): transformation and `Dataset.batch()`: ```python - dataset = tf.contrib.data.Dataset.range(200) + dataset = tf.data.Dataset.range(200) batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128)) print(batched.output_shapes) # ==> "(128,)" (the batch dimension is known) ``` @@ -130,7 +130,7 @@ def batch_and_drop_remainder(batch_size): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply} + @{tf.data.Dataset.apply} """ def _apply_fn(dataset): @@ -272,3 +272,79 @@ class _RestructuredDataset(dataset_ops.Dataset): @property def output_shapes(self): return self._output_shapes + + +class _MapAndBatchDataset(dataset_ops.MapDataset): + """A `Dataset` that maps a function over a batch of elements.""" + + def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): + """See `Dataset.map()` for details.""" + super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) + + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + self._num_parallel_batches = ops.convert_to_tensor( + num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.map_and_batch_dataset( + input_resource, + self._map_func.captured_inputs, + f=self._map_func, + batch_size=self._batch_size, + num_parallel_batches=self._num_parallel_batches, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + # pylint: enable=protected-access + + @property + def output_shapes(self): + return nest.pack_sequence_as(self._output_shapes, [ + tensor_shape.vector(tensor_util.constant_value( + self._batch_size)).concatenate(s) + for s in nest.flatten(self._output_shapes) + ]) + + @property + def output_types(self): + return self._output_types + + +def map_and_batch(map_func, batch_size, num_parallel_batches=1): + """Fused implementation of `map` and `batch`. + + Maps `map_func` across `batch_size` consecutive elements of this dataset + and then combines them into a batch. Similarly to `batch_and_drop_remainder`, + if the batch size does not evenly divide the input dataset size, this + transformation will drop the final smaller element. + + + Functionally, it is equivalent to `map` followed by + `batch_and_drop_remainder`. However, by fusing the two transformations + together, the implementation can be more efficient. This transformation is a + stop gap solution for performance critical workloads. Once automatic input + pipeline optimization are implemented, the fusing of map and batch will not + need to be exposed at the API level and this method will be removed. + + Args: + map_func: A function mapping a nested structure of tensors to another + nested structure of tensors. + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the + number of batches to create in parallel. On one hand, higher values can + help mitigate the effect of stragglers. On the other hand, higher values + can increasing contention if CPU is scarce. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _MapAndBatchDataset(dataset, map_func, batch_size, + num_parallel_batches) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 7e58caffa940984a4ee8e273a32ff1fb9235e026..45d6dbe7438957029b4d6b71e181cb1fc3596ecb 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -23,14 +23,10 @@ from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops -# pylint: disable=unused-import -from tensorflow.python.data.ops.iterator import Iterator -# pylint: enable=unused-import from tensorflow.python.data.util import nest -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops -from tensorflow.python.ops import script_ops +from tensorflow.python.util import deprecation class Dataset(dataset_ops.Dataset): @@ -45,6 +41,7 @@ class Dataset(dataset_ops.Dataset): super(Dataset, self).__init__() self._dataset = dataset + @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.") def make_dataset_resource(self): return self._as_variant_tensor() @@ -60,6 +57,7 @@ class Dataset(dataset_ops.Dataset): return self._dataset.output_types @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.") def from_tensors(tensors): """Creates a `Dataset` with a single element, comprising the given tensors. @@ -72,6 +70,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.TensorDataset(tensors)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") def from_tensor_slices(tensors): """Creates a `Dataset` whose elements are slices of the given tensors. @@ -85,6 +84,8 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.TensorSliceDataset(tensors)) @staticmethod + @deprecation.deprecated(None, + "Use `tf.data.Dataset.from_sparse_tensor_slices()`.") def from_sparse_tensor_slices(sparse_tensor): """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. @@ -97,6 +98,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.") def from_generator(generator, output_types, output_shapes=None): """Creates a `Dataset` whose elements are generated by `generator`. @@ -134,125 +136,11 @@ class Dataset(dataset_ops.Dataset): Returns: A `Dataset`. """ - if not callable(generator): - raise TypeError("`generator` must be callable.") - if output_shapes is None: - output_shapes = nest.map_structure( - lambda _: tensor_shape.TensorShape(None), output_types) - else: - output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - - flattened_types = nest.flatten(output_types) - flattened_shapes = nest.flatten(output_shapes) - - generator_state = dataset_ops.Dataset._GeneratorState(generator) - - def get_iterator_id_map_fn(unused_dummy): - """Creates a unique `iterator_id` for each pass over the dataset. - - The "iterator_id" disambiguates between multiple concurrently - existing iterators. - - Args: - unused_dummy: Ignored value. - - Returns: - A `tf.int64` tensor whose value uniquely identifies an iterator in - `generator_state`. - """ - return script_ops.py_func( - generator_state.get_next_id, [], dtypes.int64, stateful=True) - - def generator_map_fn(iterator_id_t): - """Generates the next element from iterator with ID `iterator_id_t`. - - We map this function across an infinite repetition of the - `iterator_id_t`, and raise `StopIteration` to terminate the iteration. - - Args: - iterator_id_t: A `tf.int64` tensor whose value uniquely identifies - the iterator in `generator_state` from which to generate an element. - - Returns: - A nested structure of tensors representing an element from the iterator. - """ - - def generator_py_func(iterator_id): - """A `py_func` that will be called to invoke the iterator.""" - try: - values = next(generator_state.get_iterator(iterator_id)) - except StopIteration: - generator_state.iterator_completed(iterator_id) - raise StopIteration("Iteration finished.") - - # Use the same _convert function from the py_func() implementation to - # convert the returned values to arrays early, so that we can inspect - # their values. - # pylint: disable=protected-access - ret_arrays = [ - script_ops.FuncRegistry._convert(ret) - for ret in nest.flatten_up_to(output_types, values) - ] - # pylint: enable=protected-access - - # Additional type and shape checking to ensure that the components - # of the generated element match the `output_types` and `output_shapes` - # arguments. - for (ret_array, expected_dtype, expected_shape) in zip( - ret_arrays, flattened_types, flattened_shapes): - if ret_array.dtype != expected_dtype.as_numpy_dtype: - raise TypeError( - "`generator` yielded an element of type %s where an element " - "of type %s was expected." % (ret_array.dtype, - expected_dtype.as_numpy_dtype)) - if not expected_shape.is_compatible_with(ret_array.shape): - raise ValueError( - "`generator` yielded an element of shape %s where an element " - "of shape %s was expected." % (ret_array.shape, expected_shape)) - - return ret_arrays - - flat_values = script_ops.py_func( - generator_py_func, [iterator_id_t], flattened_types, stateful=True) - - # The `py_func()` op drops the inferred shapes, so we add them back in - # here. - if output_shapes is not None: - for ret_t, shape in zip(flat_values, flattened_shapes): - ret_t.set_shape(shape) - - return nest.pack_sequence_as(output_types, flat_values) - - # This function associates each traversal of `generator` with a unique - # iterator ID. - def flat_map_fn(iterator_id_t): - # First, generate an infinite dataset containing the iterator ID repeated - # forever. - repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None) - - # The `generator_map_fn` gets the next element from the iterator with the - # relevant ID, and raises StopIteration when that iterator contains no - # more elements. - return repeated_id.map(generator_map_fn) - - # A single-element dataset that, each time it is evaluated, contains a - # freshly-generated and unique (for the returned dataset) int64 - # ID that will be used to identify the appropriate Python state, which - # is encapsulated in `generator_state`, and captured in - # `get_iterator_id_map_fn`. - dummy = 0 - id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn) - - # A dataset that contains all of the elements generated by a - # single iterator created from `generator`, identified by the - # iterator ID contained in `id_dataset`. Lifting the iteration - # into a flat_map here enables multiple repetitions and/or nested - # versions of the returned dataset to be created, because it forces - # the generation of a new ID for each version. - return id_dataset.flat_map(flat_map_fn) + return Dataset(dataset_ops.Dataset.from_generator( + generator, output_types, output_shapes)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") def range(*args): """Creates a `Dataset` of a step-separated range of values. @@ -282,6 +170,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.RangeDataset(*args)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.") def zip(datasets): """Creates a `Dataset` by zipping together the given datasets. @@ -361,6 +250,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.") def list_files(file_pattern): """A dataset of all files matching a pattern. @@ -397,6 +287,8 @@ class Dataset(dataset_ops.Dataset): """ return Dataset(dataset_ops.RepeatDataset(self._dataset, count)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.") def enumerate(self, start=0): """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`.""" @@ -514,8 +406,10 @@ class Dataset(dataset_ops.Dataset): """ return Dataset(self._dataset.shard(num_shards, index)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.") def ignore_errors(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors()`.""" + """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`.""" return self.apply(error_ops.ignore_errors()) @@ -562,17 +456,26 @@ class Dataset(dataset_ops.Dataset): dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes, padding_values)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.") def dense_to_sparse_batch(self, batch_size, row_shape): """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`.""" return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.group_by_window())`.") def group_by_window(self, key_func, reduce_func, window_size): """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`.""" return self.apply( grouping.group_by_window(key_func, reduce_func, window_size)) + @deprecation.deprecated_args( + None, + "Replace `num_threads=T` with `num_parallel_calls=T`. Replace " + "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.", + "num_threads", "output_buffer_size") def map(self, map_func, num_threads=None, @@ -694,6 +597,7 @@ class Dataset(dataset_ops.Dataset): dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length, block_length)) + @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.") def unbatch(self): """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`.""" @@ -738,3 +642,48 @@ class Dataset(dataset_ops.Dataset): if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`transformation_func` must return a Dataset.") return Dataset(dataset) + + +def get_single_element(dataset): + """Returns the single element in `dataset` as a nested structure of tensors. + + This function enables you to use a @{tf.data.Dataset} in a stateless + "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This can be useful when your preprocessing transformations are expressed + as a `Dataset`, and you want to use the transformation at serving time. + For example: + + ```python + input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) + + def preprocessing_fn(input_str): + # ... + return image, label + + dataset = (tf.data.Dataset.from_tensor_slices(input_batch) + .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) + .batch(BATCH_SIZE)) + + image_batch, label_batch = tf.contrib.data.get_single_element(dataset) + ``` + + Args: + dataset: A @{tf.data.Dataset} object containing a single element. + + Returns: + A nested structure of @{tf.Tensor} objects, corresponding to the single + element of `dataset`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + InvalidArgumentError (at runtime): if `dataset` does not contain exactly + one element. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + return nest.pack_sequence_as( + dataset.output_types, + gen_dataset_ops.dataset_to_single_element( + dataset._as_variant_tensor(), # pylint: disable=protected-access + output_types=nest.flatten(dataset.output_types), + output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py index 40e7315f1f6f1fa67874589cdc7ae96382d80e28..ac2b386b81532b801139baa00fd5edd4ecd6ef0a 100644 --- a/tensorflow/contrib/data/python/ops/enumerate_ops.py +++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py @@ -47,7 +47,7 @@ def enumerate_dataset(start=0): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index dffa8b7f7dc66635948a5ba52868ffea121ee164..238bb52b0205f9ab66f479f1b92e72ab6e38725b 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -30,7 +30,7 @@ def ignore_errors(): example: ```python - dataset = tf.contrib.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) + dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError. dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error")) @@ -42,7 +42,7 @@ def ignore_errors(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 2cf7e8f4ee5fc820d6f8321695abfef53d88498e..6df7b22fb69bb14c41a26bd630a825442f67ee23 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -57,7 +57,7 @@ def group_by_window(key_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. Raises: ValueError: if neither or both of {`window_size`, `window_size_func`} are diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 98b1fe4dbf527b6471852bc7e10ab8ab9f29c2ac..2e1c3153ca78e20e2628e8754b9827b817f8c732 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -28,11 +28,13 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile +from tensorflow.python.util import deprecation class TextLineDataset(contrib_dataset_ops.Dataset): """A `Dataset` comprising lines from one or more text files.""" + @deprecation.deprecated(None, "Use `tf.data.TextLineDataset`.") def __init__(self, filenames, compression_type=None, buffer_size=None): """Creates a `TextLineDataset`. @@ -52,6 +54,7 @@ class TextLineDataset(contrib_dataset_ops.Dataset): class TFRecordDataset(contrib_dataset_ops.Dataset): """A `Dataset` comprising records from one or more TFRecord files.""" + @deprecation.deprecated(None, "Use `tf.data.TFRecordDataset`.") def __init__(self, filenames, compression_type=None, buffer_size=None): """Creates a `TFRecordDataset`. @@ -70,6 +73,7 @@ class TFRecordDataset(contrib_dataset_ops.Dataset): class FixedLengthRecordDataset(contrib_dataset_ops.Dataset): """A `Dataset` of fixed-length records from one or more binary files.""" + @deprecation.deprecated(None, "Use `tf.data.FixedLengthRecordDataset`.") def __init__(self, filenames, record_bytes, diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index f4f2d4285443a5b2b5e02b3126f2edd9bd47a937..56f526a330bfbea7305b0754bfd114c5e97db506 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +29,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import resource_variable_ops def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): @@ -48,7 +48,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -68,26 +68,20 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): num_classes = (target_dist_t.shape[0].value or array_ops.shape(target_dist_t)[0]) smoothing_constant = 10 - # Disable device functions and colocation constraints so that the variable - # will be placed with the eventual DT_VARIANT dataset tensor. - with ops.colocate_with(None, ignore_existing=True): - num_examples_per_class_seen = resource_variable_ops.ResourceVariable( - initial_value=array_ops.fill([num_classes], - np.int64(smoothing_constant)), - trainable=False, - collections=[ops.GraphKeys.LOCAL_VARIABLES], - name="local_class_count", - dtype=dtypes.int64) - - def update_estimate_and_tile(c): - return array_ops.tile( - array_ops.expand_dims( - _estimate_data_distribution(c, num_examples_per_class_seen), 0), - [dist_estimation_batch_size, 1]) + initial_examples_per_class_seen = array_ops.fill( + [num_classes], np.int64(smoothing_constant)) + + def update_estimate_and_tile(num_examples_per_class_seen, c): + updated_examples_per_class_seen, dist = _estimate_data_distribution( + c, num_examples_per_class_seen) + tiled_dist = array_ops.tile( + array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) + return updated_examples_per_class_seen, tiled_dist initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .map(update_estimate_and_tile).apply(batching - .unbatch())) + .apply(scan_ops.scan(initial_examples_per_class_seen, + update_estimate_and_tile)) + .apply(batching.unbatch())) acceptance_dist_ds = initial_dist_ds.map( lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) @@ -174,20 +168,21 @@ def _estimate_data_distribution(c, num_examples_per_class_seen): Args: c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: A `ResourceVariable` containing counts. - Type `int64`, shape `[num_classes]`. + num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, + containing counts. Returns: + num_examples_per_lass_seen: Updated counts. Type `int64`, shape + `[num_classes]`. dist: The updated distribution. Type `float32`, shape `[num_classes]`. """ num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in - # batch. But do this asynchronously to avoid performing a - # cross-device round-trip. Just use the cached value. - num_examples_per_class_seen = num_examples_per_class_seen.assign_add( - math_ops.reduce_sum( + # Update the class-count based on what labels are seen in batch. + num_examples_per_class_seen = math_ops.add( + num_examples_per_class_seen, math_ops.reduce_sum( array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) init_prob_estimate = math_ops.truediv( num_examples_per_class_seen, math_ops.reduce_sum(num_examples_per_class_seen)) - return math_ops.cast(init_prob_estimate, dtypes.float32) + dist = math_ops.cast(init_prob_estimate, dtypes.float32) + return num_examples_per_class_seen, dist diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5acaed48a3d73e93706bdd0b5b2d614b0c565ab7 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -0,0 +1,182 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Scan dataset transformation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class _ScanDataset(dataset_ops.Dataset): + """A dataset that scans a function across its input.""" + + def __init__(self, input_dataset, initial_state, scan_func): + """See `scan()` for details.""" + super(_ScanDataset, self).__init__() + self._input_dataset = input_dataset + + with ops.name_scope("initial_state"): + self._initial_state = nest.pack_sequence_as(initial_state, [ + ops.convert_to_tensor(t, name="component_%d" % i) + for i, t in enumerate(nest.flatten(initial_state)) + ]) + + # Compute initial values for the state shapes and types based on + # the initial state. These will be refined by running + # `tf_scan_func` one or more times below. + self._state_shapes = nest.pack_sequence_as( + self._initial_state, + [t.shape for t in nest.flatten(self._initial_state)]) + self._state_types = nest.pack_sequence_as( + self._initial_state, + [t.dtype for t in nest.flatten(self._initial_state)]) + + # Will be populated by calling `tf_scan_func`. + self._output_shapes = None + self._output_types = None + + # Iteratively rerun the scan function until reaching a fixed pont on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + flat_state_shapes = nest.flatten(self._state_shapes) + flat_state_types = nest.flatten(self._state_types) + + # Create a list in which `tf_scan_func` will store the s + flat_new_state_shapes = [] + + @function.Defun( + *(flat_state_types + nest.flatten(input_dataset.output_types))) + def tf_scan_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the state and input_dataset. + for arg, shape in zip( + args, + flat_state_shapes + nest.flatten(input_dataset.output_shapes)): + arg.set_shape(shape) + + pivot = len(flat_state_shapes) + old_state = nest.pack_sequence_as(self._initial_state, args[:pivot]) + input_value = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + + ret = scan_func(old_state, input_value) + if not isinstance(ret, collections.Sequence) or len(ret) != 2: + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + new_state, output_value = ret + + flat_new_state = [ + ops.convert_to_tensor(t) for t in nest.flatten(new_state) + ] + flat_output_value = [ + ops.convert_to_tensor(t) for t in nest.flatten(output_value) + ] + + # Extract shape information from the returned values. + flat_new_state_shapes.extend([t.shape for t in flat_new_state]) + self._output_shapes = nest.pack_sequence_as( + output_value, [t.shape for t in flat_output_value]) + + # Extract and validate type information from the returned values. + for t, dtype in zip(flat_new_state, flat_state_types): + if t.dtype != dtype: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, nest.pack_sequence_as( + self._state_types, [t.dtype for t in flat_new_state]))) + self._output_types = nest.pack_sequence_as( + output_value, [t.dtype for t in flat_output_value]) + + return flat_new_state + flat_output_value + + # Use the private method that will execute `tf_scan_func` but delay + # adding it to the graph in case we need to rerun the function. + tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + + weakened_state_shapes = [ + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( + weakened_shape.ndims is None or + original_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun + # `tf_scan_func`. + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._scan_func = tf_scan_func + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.scan_dataset( + input_t, + nest.flatten(self._initial_state), + self._scan_func.captured_inputs, + f=self._scan_func, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +def scan(initial_state, scan_func): + """A transformation that scans a function across an input dataset. + + This transformation is a stateful relative of @{tf.data.Dataset.map}. + In addition to mapping `scan_func` across the elements of the input dataset, + `scan()` accumulates one or more state tensors, whose initial values are + `initial_state`. + + Args: + initial_state: A nested structure of tensors, representing the initial state + of the accumulator. + scan_func: A function that maps `(old_state, input_element)` to + `(new_state, output_element). It must take two arguments and return a + pair of nested structures of tensors. The `new_state` must match the + structure of `initial_state`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _ScanDataset(dataset, initial_state, scan_func) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/sloppy_ops.py index 01e234f1d0db27277e9a38e6a259b4b064b89eaa..4f3da4320cd7d550c5d93db7534ad9950401a8c6 100644 --- a/tensorflow/contrib/data/python/ops/sloppy_ops.py +++ b/tensorflow/contrib/data/python/ops/sloppy_ops.py @@ -102,6 +102,17 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): strictly obeys), producing an element from a different underlying dataset instead. + Example usage: + + ```python + # Preprocess 4 files concurrently. + filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") + dataset = filenames.apply( + tf.contrib.data.sloppy_interleave( + lambda filename: tf.data.TFRecordDataset(filename), + cycle_length=4)) + ``` + WARNING: The order of elements in the resulting dataset is not deterministic. Use `Dataset.interleave()` if you want the elements to have a deterministic order. @@ -118,7 +129,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): return SloppyInterleaveDataset( diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index b86f5768ca55c0e18685a72e0d502151141273de..825ec652d0ac7750bff2a2428dfb947beee5347e 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -305,6 +305,8 @@ cuda_py_test( additional_deps = [ ":distributions_py", "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:client_testlib", ], ) @@ -350,6 +352,20 @@ cuda_py_test( ], ) +cuda_py_test( + name = "sinh_arcsinh_test", + size = "small", + srcs = ["python/kernel_tests/sinh_arcsinh_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "independent_test", size = "small", @@ -666,6 +682,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "absolute_value_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/absolute_value_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "affine_test", size = "large", @@ -763,6 +797,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "gumbel_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/gumbel_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "inline_test", size = "small", @@ -801,6 +854,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "permute_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/permute_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "power_transform_test", size = "small", @@ -858,10 +927,12 @@ cuda_py_test( ], ) +# Tests for SinhArcSinh bijector. The file name has the extra "_bijector" to +# avoid BUILD rule name conflicts with the distribution by the same name. cuda_py_test( - name = "sinh_arcsinh_test", + name = "sinh_arcsinh_bijector_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_test.py"], + srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index df76c7084f8c261b1f3fb7fb922de85635febdd1..f33cc1de0abc82a3a8974dba4459a55fb4c2e82c 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -51,6 +51,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.sample_stats import * +from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.test_util import * from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * @@ -83,19 +84,6 @@ _allowed_symbols = [ 'ConditionalTransformedDistribution', 'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED', - 'Affine', - 'AffineLinearOperator', - 'Bijector', - 'Chain', - 'CholeskyOuterProduct', - 'Exp', - 'Identity', - 'Inline', - 'Invert', - 'PowerTransform', - 'SigmoidCentered', - 'SoftmaxCentered', - 'Softplus', 'ReparameterizationType', 'Distribution', 'Binomial', @@ -125,6 +113,7 @@ _allowed_symbols = [ 'NormalWithSoftplusScale', 'Poisson', 'PoissonLogNormalQuadratureCompound', + 'SinhArcsinh', 'StudentT', 'StudentTWithAbsDfSoftplusScale', 'Uniform', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d65c79b2654c2949de161d6317f218d11cab43 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AbsoluteValue Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +# pylint: disable=g-importing-member +from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import AbsoluteValue +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + +# pylint: enable=g-importing-member + + +class AbsoluteValueTest(test.TestCase): + """Tests correctness of the absolute value bijector.""" + + def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + self.assertEqual("absolute_value", bijector.name) + x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3] + y = math_ops.abs(x) + + y_ = y.eval() + zeros = np.zeros((2, 3)) + + self.assertAllClose(y_, bijector.forward(x).eval()) + self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) + self.assertAllClose((zeros, zeros), + sess.run(bijector.inverse_log_det_jacobian(y))) + + # Run things twice to make sure there are no issues in caching the tuples + # returned by .inverse* + self.assertAllClose(y_, bijector.forward(x).eval()) + self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) + self.assertAllClose((zeros, zeros), + sess.run(bijector.inverse_log_det_jacobian(y))) + + def testEventNdimsMustBeZeroOrRaiseStatic(self): + with self.test_session(): + with self.assertRaisesRegexp(ValueError, "event_ndims.*was not 0"): + AbsoluteValue(event_ndims=1) + + def testEventNdimsMustBeZeroOrRaiseDynamic(self): + with self.test_session() as sess: + event_ndims = array_ops.placeholder(dtypes.int32) + abs_bijector = AbsoluteValue(event_ndims=event_ndims, validate_args=True) + with self.assertRaisesOpError("event_ndims was not 0"): + sess.run(abs_bijector.inverse_log_det_jacobian([1.]), + feed_dict={event_ndims: 1}) + + def testNegativeYRaisesForInverseIfValidateArgs(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + with self.assertRaisesOpError("y was negative"): + sess.run(bijector.inverse(-1.)) + + def testNegativeYRaisesForILDJIfValidateArgs(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + with self.assertRaisesOpError("y was negative"): + sess.run(bijector.inverse_log_det_jacobian(-1.)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index 0738754b217e5842bd0fa516915f14926083d321..405ddd292cacd8ace87d6caeebf3e8cfc347c22d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -72,7 +72,7 @@ class AffineLinearOperatorTest(test.TestCase): [3, -2, 0], [4, 3, 2]]], dtype=np.float32) - scale = linalg.LinearOperatorTriL(tril, is_non_singular=True) + scale = linalg.LinearOperatorLowerTriangular(tril, is_non_singular=True) affine = AffineLinearOperator( shift=shift, scale=scale, validate_args=True) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index 2c4b8277d01c7a2929fdde7babf809f2c16f730b..c9158117f7a982e37047e8dd2b534a30040a87d9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -76,7 +76,7 @@ class AffineBijectorTest(test.TestCase): for run in (static_run, dynamic_run): mu = -1. # Corresponds to scale = 2 - bijector = Affine(shift=mu, scale_diag=[2.], event_ndims=0) + bijector = Affine(shift=mu, scale_identity_multiplier=2., event_ndims=0) self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) @@ -84,7 +84,7 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose(-np.log(2.), run(bijector.inverse_log_det_jacobian, x)) - def testWeirdSampleNoBatchScalarViaIdentity(self): + def testWeirdSampleNoBatchScalarViaDiagMultiplier(self): with self.test_session() as sess: def static_run(fun, x): @@ -156,7 +156,7 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([np.log(0.5)], run(bijector.inverse_log_det_jacobian, x)) - def testOneBatchScalarViaDiag(self): + def testOneBatchScalarViaDiagMultiplier(self): with self.test_session() as sess: def static_run(fun, x): @@ -171,7 +171,7 @@ class AffineBijectorTest(test.TestCase): mu = [1.] # One batch, scalar. # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_diag=[1.], event_ndims=0) + bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" x = [1.] # One sample from one batches. self.assertAllClose([2.], run(bijector.forward, x)) @@ -200,7 +200,7 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0., 2], run(bijector.inverse, x)) self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - def testTwoBatchScalarIdentityViaDiag(self): + def testTwoBatchScalarIdentityViaDiagMultiplier(self): with self.test_session() as sess: def static_run(fun, x): @@ -215,7 +215,7 @@ class AffineBijectorTest(test.TestCase): mu = [1., -1] # Univariate, two batches. # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_diag=[1.], event_ndims=0) + bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" x = [1., 1] # One sample from each of two batches. self.assertAllClose([2., 0], run(bijector.forward, x)) @@ -410,13 +410,13 @@ class AffineBijectorTest(test.TestCase): bijector = Affine( shift=mu, scale_identity_multiplier=1., - scale_diag=[1.], - event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is vector" + scale_diag=[1., 1., 1.], + event_ndims=1) + self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), + self.assertAllClose(-np.log(2.**3), run(bijector.inverse_log_det_jacobian, x)) def testIdentityWithTriL(self): @@ -668,11 +668,10 @@ class AffineBijectorTest(test.TestCase): with self.assertRaisesOpError("identity_multiplier should be non-zero"): bijector.forward(1.).eval() - # Check Diag matrix with zero scaling. - bijector = Affine( - shift=mu, scale_diag=[0.0], event_ndims=0, validate_args=True) - with self.assertRaisesOpError("diagonal part must be non-zero"): - bijector.forward(1.).eval() + def testScaleDiagAndEventNdimsZeroRaises(self): + # Check Diag matrix with zero scaling. + with self.assertRaisesRegexp(ValueError, "only scale argument"): + Affine(shift=None, scale_diag=[0.0], event_ndims=0, validate_args=True) def testScalarCongruency(self): with self.test_session(): @@ -830,6 +829,15 @@ class AffineBijectorTest(test.TestCase): x=np.array( [1., 2], dtype=np.float32)) + def testScalarEventIdentityScale(self): + with self.test_session() as sess: + doubler = Affine( + scale_identity_multiplier=2., + event_ndims=0) + doubler2 = doubler.inverse_log_det_jacobian(2.) + doubler2_ildj_ = sess.run([doubler2]) + self.assertAllClose([-np.log(2.)], doubler2_ildj_) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9a905980c7581a86bbcda8c6c726da57c09fe4f8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py @@ -0,0 +1,70 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats + +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import Gumbel +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class GumbelBijectorTest(test.TestCase): + """Tests correctness of the Gumbel bijector.""" + + def testBijector(self): + with self.test_session(): + loc = 0.3 + scale = 5. + bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True) + self.assertEqual("gumbel", bijector.name) + x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32) + # Gumbel distribution + gumbel_dist = stats.gumbel_r(loc=loc, scale=scale) + y = gumbel_dist.cdf(x).astype(np.float32) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + # We should lose a dimension from calculating the determinant of the + # jacobian. + np.squeeze(gumbel_dist.logpdf(x), axis=2), + bijector.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(y).eval(), + bijector.forward_log_det_jacobian(x).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Gumbel(loc=0.3, scale=20.), lower_x=1., upper_x=100., rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True) + x = np.linspace(-10., 10., num=10).astype(np.float32) + y = np.linspace(0.01, 0.99, num=10).astype(np.float32) + assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py new file mode 100644 index 0000000000000000000000000000000000000000..54590de373441c32cc3214cb04d45cfc2d1807ed --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Permute bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.permute import Permute +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + +class PermuteBijectorTest(test.TestCase): + """Tests correctness of the Permute bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testBijector(self): + expected_permutation = np.int32([2, 0, 1]) + expected_x = np.random.randn(4, 2, 3) + expected_y = expected_x[..., expected_permutation] + + with self.test_session() as sess: + permutation_ph = array_ops.placeholder(dtype=dtypes.int32) + bijector = Permute( + permutation=permutation_ph, + validate_args=True) + [ + permutation_, + x_, + y_, + fldj, + ildj, + ] = sess.run([ + bijector.permutation, + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.forward_log_det_jacobian(expected_x), + bijector.inverse_log_det_jacobian(expected_y), + ], feed_dict={permutation_ph: expected_permutation}) + self.assertEqual("permute", bijector.name) + self.assertAllEqual(expected_permutation, permutation_) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(0., fldj, rtol=1e-6, atol=0) + self.assertAllClose(0., ildj, rtol=1e-6, atol=0) + + def testRaisesOpError(self): + with self.test_session() as sess: + with self.assertRaisesOpError("Permutation over `d` must contain"): + permutation_ph = array_ops.placeholder(dtype=dtypes.int32) + bijector = Permute( + permutation=permutation_ph, + validate_args=True) + sess.run(bijector.inverse([1.]), + feed_dict={permutation_ph: [1, 2]}) + + def testBijectiveAndFinite(self): + permutation = np.int32([2, 0, 1]) + x = np.random.randn(4, 2, 3) + y = x[..., permutation] + with self.test_session(): + bijector = Permute( + permutation=permutation, + validate_args=True) + assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py similarity index 96% rename from tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py rename to tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 230dd93a2a807cc14394e3c747c208c1f95b194d..172c180a44229089f06f250a872bc47a89991cf0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -41,7 +41,7 @@ class SinhArcsinhBijectorTest(test.TestCase): tailweight=tailweight, event_ndims=1, validate_args=True) - self.assertEqual("sinh_arcsinh", bijector.name) + self.assertEqual("SinhArcsinh", bijector.name) x = np.array([[[-2.01], [2.], [1e-4]]]).astype(np.float32) y = np.sinh((np.arcsinh(x) + skewness) * tailweight) self.assertAllClose(y, bijector.forward(x).eval()) @@ -170,6 +170,12 @@ class SinhArcsinhBijectorTest(test.TestCase): with self.assertRaisesOpError("not positive"): SinhArcsinh(tailweight=0., validate_args=True).forward(1.0).eval() + def testDefaultDtypeIsFloat32(self): + with self.test_session(): + bijector = SinhArcsinh() + self.assertEqual(bijector.tailweight.dtype, np.float32) + self.assertEqual(bijector.skewness.dtype, np.float32) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index cc7d6fd5ddda8fcdfdf6c8a3f80feeda7a42541e..2d74aa1f320149d0f7ef9e9c52b8c7053c2f74d7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -23,11 +23,11 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.linalg.python.ops import linear_operator_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops.linalg import linear_operator_diag import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -287,6 +287,26 @@ class ShapesFromLocAndScaleTest(test.TestCase): self.assertAllEqual([3], event_shape) +class GetBroadcastShapeTest(test.TestCase): + + def test_all_static_shapes_work(self): + x = array_ops.ones((2, 1, 3)) + y = array_ops.ones((1, 5, 3)) + z = array_ops.ones(()) + self.assertAllEqual([2, 5, 3], + distribution_util.get_broadcast_shape(x, y, z)) + + def test_with_some_dynamic_shapes_works(self): + x = array_ops.ones((2, 1, 3)) + y = array_ops.placeholder(x.dtype) + z = array_ops.ones(()) + with self.test_session() as sess: + bcast_shape = sess.run( + distribution_util.get_broadcast_shape(x, y, z), + feed_dict={y: np.ones((1, 5, 3)).astype(np.float32)}) + self.assertAllEqual([2, 5, 3], bcast_shape) + + class TridiagTest(test.TestCase): def testWorksCorrectlyNoBatches(self): @@ -374,5 +394,6 @@ class MixtureStddevTest(test.TestCase): self.assertAllClose(actual_devs, expected_devs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index 47ac412500d36df999225d94be0ecb7cccf75723..ee4f989dac0761f04b1b6bc88f7de598f194634e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -23,67 +23,75 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import mixture_same_family as mixture_same_family_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -class MixtureSameFamilyTest( - test_util.VectorDistributionTestHelpers, test.TestCase): +class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, + test.TestCase): def testSampleAndLogProbUnivariateShapes(self): with self.test_session(): gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=normal_lib.Normal( - loc=[-1., 1], - scale=[0.1, 0.5])) - x = gm.sample([4, 5]) + loc=[-1., 1], scale=[0.1, 0.5])) + x = gm.sample([4, 5], seed=42) log_prob_x = gm.log_prob(x) self.assertEqual([4, 5], x.shape) self.assertEqual([4, 5], log_prob_x.shape) def testSampleAndLogProbShapesBroadcastMix(self): mix_probs = np.float32([.3, .7]) - bern_probs = np.float32([[.4, .6], - [.25, .75]]) + bern_probs = np.float32([[.4, .6], [.25, .75]]) with self.test_session(): bm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=mix_probs), - components_distribution=bernoulli_lib.Bernoulli( - probs=bern_probs)) - x = bm.sample([4, 5]) + mixture_distribution=categorical_lib.Categorical(probs=mix_probs), + components_distribution=bernoulli_lib.Bernoulli(probs=bern_probs)) + x = bm.sample([4, 5], seed=42) log_prob_x = bm.log_prob(x) x_ = x.eval() self.assertEqual([4, 5, 2], x.shape) self.assertEqual([4, 5, 2], log_prob_x.shape) - self.assertAllEqual(np.ones_like(x_, dtype=np.bool), - np.logical_or(x_ == 0., x_ == 1.)) + self.assertAllEqual( + np.ones_like(x_, dtype=np.bool), np.logical_or(x_ == 0., x_ == 1.)) def testSampleAndLogProbMultivariateShapes(self): with self.test_session(): gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) - x = gm.sample([4, 5]) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) + x = gm.sample([4, 5], seed=42) log_prob_x = gm.log_prob(x) self.assertEqual([4, 5, 2], x.shape) self.assertEqual([4, 5], log_prob_x.shape) + def testSampleAndLogProbBatchMultivariateShapes(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=[[[-1., 1], + [1, -1]], + [[0., 1], + [1, 0]]], + scale_identity_multiplier=[1., 0.5])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 2, 2], x.shape) + self.assertEqual([4, 5, 2], log_prob_x.shape) + def testSampleConsistentLogProb(self): with self.test_session() as sess: gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( sess, gm, radius=1., center=[-1., 1], rtol=0.02) @@ -91,26 +99,40 @@ class MixtureSameFamilyTest( self.run_test_sample_consistent_log_prob( sess, gm, radius=1., center=[1., -1], rtol=0.02) + def testLogCdf(self): + with self.test_session() as sess: + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=normal_lib.Normal( + loc=[-1., 1], scale=[0.1, 0.5])) + x = gm.sample(10, seed=42) + actual_log_cdf = gm.log_cdf(x) + expected_log_cdf = math_ops.reduce_logsumexp( + (gm.mixture_distribution.logits + + gm.components_distribution.log_cdf(x[..., array_ops.newaxis])), + axis=1) + actual_log_cdf_, expected_log_cdf_ = sess.run([ + actual_log_cdf, expected_log_cdf]) + self.assertAllClose(actual_log_cdf_, expected_log_cdf_, + rtol=1e-6, atol=0.0) + def testSampleConsistentMeanCovariance(self): with self.test_session() as sess: gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) self.run_test_sample_consistent_mean_covariance(sess, gm) def testVarianceConsistentCovariance(self): with self.test_session() as sess: gm = mixture_same_family_lib.MixtureSameFamily( - mixture_distribution=categorical_lib.Categorical( - probs=[0.3, 0.7]), + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( - loc=[[-1., 1], [1, -1]], - scale_identity_multiplier=[1., 0.5])) + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) cov_, var_ = sess.run([gm.covariance(), gm.variance()]) self.assertAllClose(cov_.diagonal(), var_, atol=0.) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index 61c2185e86b2bcf0fab22abfa6f305e27e3f1459..1e514fe0ff21cd53c8c235da417890773db50c37 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py @@ -38,7 +38,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging -distributions_py = distributions +ds = distributions def _swap_first_last_axes(array): @@ -74,7 +74,7 @@ def _test_capture_mvndiag_sample_outputs(): """Use monkey-patching to capture the output of an MVNDiag _call_sample_n.""" data_container = [] true_mvndiag_call_sample_n = ( - distributions_py.MultivariateNormalDiag._call_sample_n) + ds.MultivariateNormalDiag._call_sample_n) def _capturing_mvndiag_call_sample_n( self, sample_shape, seed, name, **kwargs): @@ -83,10 +83,10 @@ def _test_capture_mvndiag_sample_outputs(): data_container.append(samples) return samples - distributions_py.MultivariateNormalDiag._call_sample_n = ( + ds.MultivariateNormalDiag._call_sample_n = ( _capturing_mvndiag_call_sample_n) yield data_container - distributions_py.MultivariateNormalDiag._call_sample_n = ( + ds.MultivariateNormalDiag._call_sample_n = ( true_mvndiag_call_sample_n) @@ -94,7 +94,7 @@ def _test_capture_mvndiag_sample_outputs(): def _test_capture_normal_sample_outputs(): """Use monkey-patching to capture the output of an Normal _call_sample_n.""" data_container = [] - true_normal_call_sample_n = distributions_py.Normal._call_sample_n + true_normal_call_sample_n = ds.Normal._call_sample_n def _capturing_normal_call_sample_n(self, sample_shape, seed, name, **kwargs): samples = true_normal_call_sample_n( @@ -102,9 +102,9 @@ def _test_capture_normal_sample_outputs(): data_container.append(samples) return samples - distributions_py.Normal._call_sample_n = _capturing_normal_call_sample_n + ds.Normal._call_sample_n = _capturing_normal_call_sample_n yield data_container - distributions_py.Normal._call_sample_n = true_normal_call_sample_n + ds.Normal._call_sample_n = true_normal_call_sample_n def make_univariate_mixture(batch_shape, num_components): @@ -113,13 +113,13 @@ def make_univariate_mixture(batch_shape, num_components): array_ops.concat((batch_shape, [num_components]), axis=0), -1, 1, dtype=dtypes.float32) - 50. components = [ - distributions_py.Normal( + ds.Normal( loc=random_ops.random_normal(batch_shape), scale=10 * random_ops.random_uniform(batch_shape)) for _ in range(num_components) ] - cat = distributions_py.Categorical(logits, dtype=dtypes.int32) - return distributions_py.Mixture(cat, components) + cat = ds.Categorical(logits, dtype=dtypes.int32) + return ds.Mixture(cat, components) def make_multivariate_mixture(batch_shape, num_components, event_shape, @@ -141,11 +141,11 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape, scale_diag = 10 * random_ops.random_uniform(batch_and_event_shape) loc.set_shape(static_batch_and_event_shape) scale_diag.set_shape(static_batch_and_event_shape) - return distributions_py.MultivariateNormalDiag( + return ds.MultivariateNormalDiag( loc=loc, scale_diag=scale_diag) components = [create_component() for _ in range(num_components)] - cat = distributions_py.Categorical(logits, dtype=dtypes.int32) - return distributions_py.Mixture(cat, components) + cat = ds.Categorical(logits, dtype=dtypes.int32) + return ds.Mixture(cat, components) class MixtureTest(test.TestCase): @@ -170,37 +170,37 @@ class MixtureTest(test.TestCase): def testBrokenShapesStatic(self): with self.assertRaisesWithPredicateMatch(ValueError, r"cat.num_classes != len"): - distributions_py.Mixture( - distributions_py.Categorical([0.1, 0.5]), # 2 classes - [distributions_py.Normal(loc=1.0, scale=2.0)]) + ds.Mixture( + ds.Categorical([0.1, 0.5]), # 2 classes + [ds.Normal(loc=1.0, scale=2.0)]) with self.assertRaisesWithPredicateMatch( ValueError, r"\(\) and \(2,\) are not compatible"): # The value error is raised because the batch shapes of the # Normals are not equal. One is a scalar, the other is a # vector of size (2,). - distributions_py.Mixture( - distributions_py.Categorical([-0.5, 0.5]), # scalar batch + ds.Mixture( + ds.Categorical([-0.5, 0.5]), # scalar batch [ - distributions_py.Normal( + ds.Normal( loc=1.0, scale=2.0), # scalar dist - distributions_py.Normal( + ds.Normal( loc=[1.0, 1.0], scale=[2.0, 2.0]) ]) with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"): cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32) - distributions_py.Mixture( - distributions_py.Categorical(cat_logits), - [distributions_py.Normal( + ds.Mixture( + ds.Categorical(cat_logits), + [ds.Normal( loc=[1.0], scale=[2.0])]) def testBrokenShapesDynamic(self): with self.test_session(): d0_param = array_ops.placeholder(dtype=dtypes.float32) d1_param = array_ops.placeholder(dtype=dtypes.float32) - d = distributions_py.Mixture( - distributions_py.Categorical([0.1, 0.2]), [ - distributions_py.Normal( - loc=d0_param, scale=d0_param), distributions_py.Normal( + d = ds.Mixture( + ds.Categorical([0.1, 0.2]), [ + ds.Normal( + loc=d0_param, scale=d0_param), ds.Normal( loc=d1_param, scale=d1_param) ], validate_args=True) @@ -211,21 +211,21 @@ class MixtureTest(test.TestCase): def testBrokenTypes(self): with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"): - distributions_py.Mixture(None, []) - cat = distributions_py.Categorical([0.3, 0.2]) + ds.Mixture(None, []) + cat = ds.Categorical([0.3, 0.2]) # components must be a list of distributions with self.assertRaisesWithPredicateMatch( TypeError, "all .* must be Distribution instances"): - distributions_py.Mixture(cat, [None]) + ds.Mixture(cat, [None]) with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"): - distributions_py.Mixture( + ds.Mixture( cat, [ - distributions_py.Normal(loc=[1.0], scale=[2.0]), - distributions_py.Normal(loc=[np.float16(1.0)], - scale=[np.float16(2.0)]), + ds.Normal(loc=[1.0], scale=[2.0]), + ds.Normal(loc=[np.float16(1.0)], + scale=[np.float16(2.0)]), ]) with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"): - distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None) + ds.Mixture(ds.Categorical([0.3, 0.2]), None) # TODO(ebrevdo): once distribution Domains have been added, add a # test to ensure that the domains of the distributions in a @@ -364,13 +364,13 @@ class MixtureTest(test.TestCase): component_devs = np.array([0.05, 2.33]) ground_truth_stddev = 5.3120805 - mixture_dist = distributions_py.Mixture( - cat=distributions_py.Categorical(probs=cat_probs), + mixture_dist = ds.Mixture( + cat=ds.Categorical(probs=cat_probs), components=[ - distributions_py.Normal(loc=component_means[0], - scale=component_devs[0]), - distributions_py.Normal(loc=component_means[1], - scale=component_devs[1]), + ds.Normal(loc=component_means[0], + scale=component_devs[0]), + ds.Normal(loc=component_means[1], + scale=component_devs[1]), ]) mix_dev = mixture_dist.stddev() with self.test_session() as sess: @@ -517,22 +517,22 @@ class MixtureTest(test.TestCase): random_seed.set_random_seed(654321) components = [ - distributions_py.Normal( + ds.Normal( loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas) ] - cat = distributions_py.Categorical( + cat = ds.Categorical( logits, dtype=dtypes.int32, name="cat1") - dist1 = distributions_py.Mixture(cat, components, name="mixture1") + dist1 = ds.Mixture(cat, components, name="mixture1") samples1 = dist1.sample(n, seed=123456).eval() random_seed.set_random_seed(654321) components2 = [ - distributions_py.Normal( + ds.Normal( loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas) ] - cat2 = distributions_py.Categorical( + cat2 = ds.Categorical( logits, dtype=dtypes.int32, name="cat2") - dist2 = distributions_py.Mixture(cat2, components2, name="mixture2") + dist2 = ds.Mixture(cat2, components2, name="mixture2") samples2 = dist2.sample(n, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -665,15 +665,15 @@ class MixtureTest(test.TestCase): e_x = np.exp(x - np.max(x)) return e_x / e_x.sum() - # Construct the distributions_py.Mixture object. + # Construct the ds.Mixture object. mixture_weights = _scalar_univariate_softmax(mixture_weight_logits) means = [np.random.uniform(low=-10, high=10, size=()).astype(np.float32) for _ in range(n_components)] sigmas = [np.ones(shape=(), dtype=np.float32) for _ in range(n_components)] - cat_tf = distributions_py.Categorical(probs=mixture_weights) - components_tf = [distributions_py.Normal(loc=mu, scale=sigma) + cat_tf = ds.Categorical(probs=mixture_weights) + components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32) @@ -718,10 +718,10 @@ class MixtureTest(test.TestCase): for _ in range(n_components)] sigmas = [np.ones(shape=psize, dtype=np.float32) for _ in range(n_components)] - cat_tf = distributions_py.Categorical(probs=mixture_weights) - components_tf = [distributions_py.Normal(loc=mu, scale=sigma) + cat_tf = ds.Categorical(probs=mixture_weights) + components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32) xs_to_check = [ @@ -750,6 +750,20 @@ class MixtureTest(test.TestCase): self.assertAllClose(x_cdf_tf_result, scipy_cdf_result) self.assertAllClose(np.exp(x_log_cdf_tf_result), scipy_cdf_result) + def testSampleBimixGamma(self): + """Tests a bug in the underlying tf.Gamma op. + + Mixture's use of dynamic partition requires `random_gamma` correctly returns + an empty `Tensor`. + """ + with self.test_session(): + gm = ds.Mixture( + cat=ds.Categorical(probs=[.3, .7]), + components=[ds.Gamma(1., 2.), + ds.Gamma(2., 1.)]) + x_ = gm.sample().eval() + self.assertAllEqual([], x_.shape) + class MixtureBenchmark(test.Benchmark): @@ -784,7 +798,7 @@ class MixtureBenchmark(test.Benchmark): 2, "mvn_diag\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time") def create_distribution(batch_size, num_components, num_features): - cat = distributions_py.Categorical( + cat = ds.Categorical( logits=np.random.randn(batch_size, num_components)) mus = [ variables.Variable(np.random.randn(batch_size, num_features)) @@ -795,9 +809,9 @@ class MixtureBenchmark(test.Benchmark): for _ in range(num_components) ] components = list( - distributions_py.MultivariateNormalDiag( + ds.MultivariateNormalDiag( loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas)) - return distributions_py.Mixture(cat, components) + return ds.Mixture(cat, components) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -824,7 +838,7 @@ class MixtureBenchmark(test.Benchmark): return np.stack([np.dot(np.transpose(z), z) for z in x]) def create_distribution(batch_size, num_components, num_features): - cat = distributions_py.Categorical( + cat = ds.Categorical( logits=np.random.randn(batch_size, num_components)) mus = [ variables.Variable(np.random.randn(batch_size, num_features)) @@ -836,10 +850,10 @@ class MixtureBenchmark(test.Benchmark): for _ in range(num_components) ] components = list( - distributions_py.MultivariateNormalTriL( + ds.MultivariateNormalTriL( loc=mu, scale_tril=linalg_ops.cholesky(sigma)) for (mu, sigma) in zip(mus, sigmas)) - return distributions_py.Mixture(cat, components) + return ds.Mixture(cat, components) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py new file mode 100644 index 0000000000000000000000000000000000000000..88b48736dd55270fb4e149ae1560911179e446e9 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py @@ -0,0 +1,221 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SinhArcsinh.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib import distributions +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + +ds = distributions +rng = np.random.RandomState(123) + + +class SinhArcsinhTest(test.TestCase): + + def test_default_is_same_as_normal(self): + b = 10 + scale = rng.rand(b) + 0.5 + loc = rng.randn(b) + with self.test_session() as sess: + norm = ds.Normal( + loc=loc, + scale=scale, + validate_args=True) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + validate_args=True) + + x = rng.randn(5, b) + norm_pdf, sasnorm_pdf = sess.run([norm.prob(x), sasnorm.prob(x)]) + self.assertAllClose(norm_pdf, sasnorm_pdf) + + norm_samps, sasnorm_samps = sess.run( + [norm.sample(10000, seed=0), + sasnorm.sample(10000, seed=0)]) + self.assertAllClose(loc, sasnorm_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + norm_samps.mean(axis=0), sasnorm_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + norm_samps.std(axis=0), sasnorm_samps.std(axis=0), atol=0.1) + + def test_broadcast_params_dynamic(self): + with self.test_session() as sess: + loc = array_ops.placeholder(dtypes.float64) + scale = array_ops.placeholder(dtypes.float64) + skewness = array_ops.placeholder(dtypes.float64) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + skewness=skewness, + validate_args=True) + + samp = sess.run(sasnorm.sample(), + feed_dict={loc: rng.rand(5), + scale: np.float64(rng.rand()), # Scalar + skewness: rng.rand(5)}) + self.assertAllEqual((5,), samp.shape) + + def test_passing_in_laplace_plus_defaults_is_same_as_laplace(self): + b = 10 + scale = rng.rand(b) + 0.5 + loc = rng.randn(b) + with self.test_session() as sess: + lap = ds.Laplace( + loc=loc, + scale=scale, + validate_args=True) + saslap = ds.SinhArcsinh( + loc=loc, + scale=scale, + distribution=ds.Laplace(np.float64(0), np.float64(1)), + validate_args=True) + + x = rng.randn(5, b) + lap_pdf, saslap_pdf = sess.run([lap.prob(x), saslap.prob(x)]) + self.assertAllClose(lap_pdf, saslap_pdf) + + lap_samps, saslap_samps = sess.run( + [lap.sample(10000, seed=0), + saslap.sample(10000, seed=0)]) + self.assertAllClose(loc, saslap_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + lap_samps.mean(axis=0), saslap_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + lap_samps.std(axis=0), saslap_samps.std(axis=0), atol=0.1) + + def test_tailweight_small_gives_fewer_outliers_than_normal(self): + batch_size = 10 + scale = rng.rand(batch_size) + 0.5 + loc = 0.1 * rng.randn(batch_size) + with self.test_session() as sess: + norm = ds.Normal( + loc=loc, + scale=scale, + validate_args=True) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + tailweight=0.1, + validate_args=True) + + # sasnorm.pdf(x) is smaller on outliers (+-10 are outliers) + x = np.float64([[-10] * batch_size, [10] * batch_size]) # Shape [2, 10] + norm_lp, sasnorm_lp = sess.run([norm.log_prob(x), sasnorm.log_prob(x)]) + np.testing.assert_array_less(sasnorm_lp, norm_lp) + + # 0.1% quantile and 99.9% quantile are outliers, and should be more + # extreme in the normal. The 97.772% quantiles should be the same. + norm_samps, sasnorm_samps = sess.run( + [norm.sample(int(5e5), seed=1), + sasnorm.sample(int(5e5), seed=1)]) + np.testing.assert_array_less( + np.percentile(norm_samps, 0.1, axis=0), + np.percentile(sasnorm_samps, 0.1, axis=0)) + np.testing.assert_array_less( + np.percentile(sasnorm_samps, 99.9, axis=0), + np.percentile(norm_samps, 99.9, axis=0)) + # 100. * sp.stats.norm.cdf(2.) + q = 100 * 0.97724986805182079 + self.assertAllClose( + np.percentile(sasnorm_samps, q, axis=0), + np.percentile(norm_samps, q, axis=0), + rtol=0.03) + self.assertAllClose( + np.percentile(sasnorm_samps, 100 - q, axis=0), + np.percentile(norm_samps, 100 - q, axis=0), + rtol=0.03) + + def test_tailweight_large_gives_more_outliers_than_normal(self): + batch_size = 10 + scale = rng.rand(batch_size) + 0.5 + loc = np.float64(0.) + with self.test_session() as sess: + norm = ds.Normal( + loc=loc, + scale=scale, + validate_args=True) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + tailweight=3., + validate_args=True) + + # norm.pdf(x) is smaller on outliers (+-10 are outliers) + x = np.float64([[-10] * batch_size, [10] * batch_size]) # Shape [2, 10] + norm_lp, sasnorm_lp = sess.run([norm.log_prob(x), sasnorm.log_prob(x)]) + np.testing.assert_array_less(norm_lp, sasnorm_lp) + + # 0.1% quantile and 99.9% quantile are outliers, and should be more + # extreme in the sasnormal. The 97.772% quantiles should be the same. + norm_samps, sasnorm_samps = sess.run( + [norm.sample(int(5e5), seed=2), + sasnorm.sample(int(5e5), seed=2)]) + np.testing.assert_array_less( + np.percentile(sasnorm_samps, 0.1, axis=0), + np.percentile(norm_samps, 0.1, axis=0)) + np.testing.assert_array_less( + np.percentile(norm_samps, 99.9, axis=0), + np.percentile(sasnorm_samps, 99.9, axis=0)) + # 100. * sp.stats.norm.cdf(2.) + q = 100 * 0.97724986805182079 + self.assertAllClose( + np.percentile(sasnorm_samps, q, axis=0), + np.percentile(norm_samps, q, axis=0), + rtol=0.03) + self.assertAllClose( + np.percentile(sasnorm_samps, 100 - q, axis=0), + np.percentile(norm_samps, 100 - q, axis=0), + rtol=0.03) + + def test_positive_skewness_moves_mean_to_the_right(self): + batch_size = 10 + scale = rng.rand(batch_size) + 0.5 + loc = rng.randn(batch_size) + with self.test_session() as sess: + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + skewness=3.0, + validate_args=True) + + sasnorm_samps = sess.run(sasnorm.sample(10000, seed=4)) + np.testing.assert_array_less(loc, sasnorm_samps.mean(axis=0)) + + def test_pdf_reflected_for_negative_skewness(self): + with self.test_session() as sess: + sas_pos_skew = ds.SinhArcsinh( + loc=0., + scale=1., + skewness=2., + validate_args=True) + sas_neg_skew = ds.SinhArcsinh( + loc=0., + scale=1., + skewness=-2., + validate_args=True) + x = np.linspace(-2, 2, num=5).astype(np.float32) + self.assertAllClose( + *sess.run([sas_pos_skew.prob(x), sas_neg_skew.prob(x[::-1])])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 6269dc5d72f7c9003b60a1635545b91502087e3a..4001530f6654a656891ebc15397cc3f618711bd8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -80,6 +80,42 @@ class TransformedDistributionTest(test.TestCase): with self.test_session(graph=g): self.assertAllClose(expected, actual.eval(), atol=0, rtol=0.01) + def testNonInjectiveTransformedDistribution(self): + g = ops.Graph() + with g.as_default(): + mu = 1. + sigma = 2.0 + abs_normal = self._cls()( + distribution=ds.Normal(loc=mu, scale=sigma), + bijector=bs.AbsoluteValue(event_ndims=0)) + sp_normal = stats.norm(mu, sigma) + + # sample + sample = abs_normal.sample(100000, seed=235) + self.assertAllEqual([], abs_normal.event_shape) + with self.test_session(graph=g): + sample_ = sample.eval() + self.assertAllEqual([], abs_normal.event_shape_tensor().eval()) + + # Abs > 0, duh! + np.testing.assert_array_less(0, sample_) + + # Let X ~ Normal(mu, sigma), Y := |X|, then + # P[Y < 0.77] = P[-0.77 < X < 0.77] + self.assertAllClose( + sp_normal.cdf(0.77) - sp_normal.cdf(-0.77), + (sample_ < 0.77).mean(), rtol=0.01) + + # p_Y(y) = p_X(-y) + p_X(y), + self.assertAllClose( + sp_normal.pdf(1.13) + sp_normal.pdf(-1.13), + abs_normal.prob(1.13).eval()) + + # Log[p_Y(y)] = Log[p_X(-y) + p_X(y)] + self.assertAllClose( + np.log(sp_normal.pdf(2.13) + sp_normal.pdf(-2.13)), + abs_normal.log_prob(2.13).eval()) + def testCachedSamples(self): exp_forward_only = bs.Exp(event_ndims=0) exp_forward_only._inverse = self._make_unimplemented( @@ -172,6 +208,19 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(actual_mvn_entropy, fake_mvn.entropy().eval()) + def testScalarBatchScalarEventIdentityScale(self): + with self.test_session() as sess: + exp2 = self._cls()( + ds.Exponential(rate=0.25), + bijector=ds.bijectors.Affine( + scale_identity_multiplier=2., + event_ndims=0)) + log_prob = exp2.log_prob(1.) + log_prob_ = sess.run(log_prob) + base_log_prob = -0.5 * 0.25 + np.log(0.25) + ildj = np.log(2.) + self.assertAllClose(base_log_prob - ildj, log_prob_, rtol=1e-6, atol=0.) + class ScalarToMultiTest(test.TestCase): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index 070ee61be314905239e11e8ed3b39f6ffa7510a7..aea4d4250383f5a6ae1af5545e06db08ac3788a3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -22,9 +22,9 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib +from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py index a7140cd98b4fbf2b6fc6c505a884619957e6eef1..a5d837d4541b63922aea2fcdf648898b391c662d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py @@ -251,6 +251,22 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers, center=0.15, rtol=0.1) + def test_pdf_reflected_for_negative_skewness(self): + with self.test_session() as sess: + sas_pos_skew = ds.VectorSinhArcsinhDiag( + loc=[0.], + scale_identity_multiplier=1., + skewness=2., + validate_args=True) + sas_neg_skew = ds.VectorSinhArcsinhDiag( + loc=[0.], + scale_identity_multiplier=1., + skewness=-2., + validate_args=True) + x = np.linspace(-2, 2, num=5).astype(np.float32).reshape(5, 1) + self.assertAllClose( + *sess.run([sas_pos_skew.prob(x), sas_neg_skew.prob(x[::-1])])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 5196954aea2b48964b9a89ef217d74c7b6dd88df..e62f900bbfe798d8981695119e9be1afd88d9281 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Bijector Ops. +@@AbsoluteValue @@Affine @@AffineLinearOperator @@Bijector @@ -21,9 +22,11 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@Gumbel @@Identity @@Inline @@Invert +@@Permute @@PowerTransform @@Sigmoid @@SigmoidCentered @@ -39,14 +42,17 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member +from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import * from tensorflow.contrib.distributions.python.ops.bijectors.affine import * from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import * from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py new file mode 100644 index 0000000000000000000000000000000000000000..6049419818e18c54209f0be95d41fcecf6627b7e --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AbsoluteValue bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["AbsoluteValue"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b84502003ab6c0c4ffdda21eea162f441509e1fa --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py @@ -0,0 +1,132 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AbsoluteValue bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "AbsoluteValue", +] + + +class AbsoluteValue(bijector.Bijector): + """Computes `Y = g(X) = Abs(X)`, element-wise. + + This non-injective bijector allows for transformations of scalar distributions + with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. + + * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse + `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. + * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse + (the set inverse is the singleton `{0}`), but "works" in conjunction with + `TransformedDistribution` to produce a left semi-continuous pdf. + * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the + wrong thing, `-y, y`. This is done for efficiency. If + `validate_args == True`, `y < 0` will raise an exception. + + + ```python + abs = ds.bijectors.AbsoluteValue() + + abs.forward([-1., 0., 1.]) + ==> [1., 0., 1.] + + abs.inverse(1.) + ==> [-1., 1.] + + # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. + abs.inverse_log_det_jacobian(1.) + ==> [0., 0.] + + # Special case handling of 0. + abs.inverse(0.) + ==> [0., 0.] + + abs.inverse_log_det_jacobian(0.) + ==> [0., 0.] + ``` + + """ + + def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + """Instantiates the `AbsoluteValue` bijector. + + Args: + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness, in particular whether inputs to `inverse` and + `inverse_log_det_jacobian` are non-negative. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init"): + super(AbsoluteValue, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return math_ops.abs(x) + + def _inverse(self, y): + if self.validate_args: + y = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + y) + return -y, y + + def _inverse_log_det_jacobian(self, y): + # If event_ndims = 2, + # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), + # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. + batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] + zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + if self.validate_args: + zeros = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + zeros) + return zeros, zeros + + @property + def _is_injective(self): + return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py index d8698788c141328e72651e958d9e6368d33f6aaf..05bb9c2f9bdf35e222c94db3491157893da64ebd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py @@ -199,6 +199,11 @@ class Affine(bijector.Bijector): event_ndims, 2, message="event_ndims must be 0 or 1")], event_ndims) + if event_ndims_const == 0 and not self._is_only_identity_multiplier: + raise ValueError( + "If event_ndims == 0, the only scale argument you can pass is " + "scale_identity_multiplier. All others operate on vectors.") + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -321,7 +326,7 @@ class Affine(bijector.Bijector): shape_hint=shape_hint) if perturb_factor is not None: - return linalg.LinearOperatorUDVHUpdate( + return linalg.LinearOperatorLowRankUpdate( scale, u=perturb_factor, diag_update=perturb_diag, @@ -383,10 +388,11 @@ class Affine(bijector.Bijector): if self._is_only_identity_multiplier: # We don't pad in this case and instead let the fldj be applied # via broadcast. - d = math_ops.cast(array_ops.shape(x)[-1], dtype=self._scale.dtype) - one = ops.convert_to_tensor(1., self._scale.dtype) - return math_ops.log(math_ops.abs(self._scale)) * array_ops.where( - math_ops.equal(self._shaper.event_ndims, 0), one, d) + event_size = distribution_util.pick_vector( + math_ops.equal(self._shaper.event_ndims, 0), + [1], array_ops.shape(x))[-1] + event_size = math_ops.cast(event_size, dtype=self._scale.dtype) + return math_ops.log(math_ops.abs(self._scale)) * event_size return self.scale.log_abs_determinant() def _maybe_check_scale(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py index ae380b5cb2bc39e06aa1e187c134d7e92f6cd92f..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.contrib.linalg.python.ops import linear_operator from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -27,6 +26,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator __all__ = [ @@ -66,7 +66,7 @@ class AffineLinearOperator(bijector.Bijector): Example Use: ```python - linalg = tf.contrib.linalg + linalg = tf.linalg x = [1., 2, 3] @@ -82,7 +82,7 @@ class AffineLinearOperator(bijector.Bijector): tril = [[1., 0, 0], [2, 1, 0], [3, 2, 1]] - scale = linalg.LinearOperatorTriL(tril) + scale = linalg.LinearOperatorLowerTriangular(tril) affine = AffineLinearOperator(shift, scale) # In this case, `forward` is equivalent to: # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py index defa36a14048d35c6264c7227840ed70dcc77cbb..3ce7c26213034c7345a20faa803c94a1bfa8d579 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py @@ -81,6 +81,13 @@ class Chain(bijector.Bijector): if bijectors is None: bijectors = () self._bijectors = bijectors + + for a_bijector in bijectors: + if not a_bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijector ({})".format( + a_bijector.name)) + dtype = list(set([b.dtype for b in bijectors])) if len(dtype) > 2: raise ValueError("incompatible dtypes: %s" % dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py new file mode 100644 index 0000000000000000000000000000000000000000..cf37aa51115ed98ab263bc03bcb297a03432a7ae --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gumbel bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Gumbel"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..67f39785563255be0fe154aca3cbcf01c6a01e73 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py @@ -0,0 +1,124 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gumbel bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "Gumbel", +] + + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py index 1d0719e6a4574864ba64019b122562819606435c..2c603fe61f36dd27f4984fe6c13c11f2fb534321 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py @@ -60,6 +60,10 @@ class Invert(bijector_lib.Bijector): name: Python `str`, name given to ops managed by this object. """ + if not bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijectors.") + self._bijector = bijector super(Invert, self).__init__( event_ndims=bijector.event_ndims, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py new file mode 100644 index 0000000000000000000000000000000000000000..a187ce22d686ee1203802ae2bfe64b0e1a3ea850 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Permute bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Permute"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d8f2f41b28a88208a19824377f93882b767f03 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Permutation bijectors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + bs = tf.contrib.distributions.bijectors + + reverse = bs.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + Warning: `tf.estimator` may repeatedly build the graph thus + `Permute(np.random.permutation(event_size)).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d}`. + """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py index dac3d812eef28b6aed291db051726d0594f7316a..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py @@ -89,18 +89,18 @@ class SinhArcsinh(bijector.Bijector): """ def __init__(self, - skewness=0., - tailweight=1., + skewness=None, + tailweight=None, event_ndims=0, validate_args=False, - name="sinh_arcsinh"): + name="SinhArcsinh"): """Instantiates the `SinhArcsinh` bijector. Args: - skewness: Skewness parameter. Float-type `Tensor`. + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` - and broadcastable `shape`. + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. event_ndims: Python scalar indicating the number of dimensions associated with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be @@ -111,8 +111,12 @@ class SinhArcsinh(bijector.Bijector): self._name = name self._validate_args = validate_args with self._name_scope("init", values=[skewness, tailweight]): - self._skewness = ops.convert_to_tensor(skewness, name="skewness") - self._tailweight = ops.convert_to_tensor(tailweight, name="tailweight") + tailweight = 1. if tailweight is None else tailweight + skewness = 0. if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) if validate_args: self._tailweight = control_flow_ops.with_dependencies([ diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index db20d170e1aafef02ec19f7429dcf3e1b9a804f6..f1b7bf468e92913e6d1d5dd965de9c3dc220f9ed 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -106,6 +106,17 @@ class ConditionalTransformedDistribution( distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + if self.bijector._is_injective: # pylint: disable=protected-access + return self._finish_log_prob_for_one_fiber(y, x, ildj, + distribution_kwargs) + + lp_on_fibers = [ + self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, distribution_kwargs) + for x_i, ildj_i in zip(x, ildj)] + return math_ops.reduce_logsumexp(array_ops.stack(lp_on_fibers), axis=0) + + def _finish_log_prob_for_one_fiber(self, y, x, ildj, distribution_kwargs): + """Finish computation of log_prob on one element of the inverse image.""" x = self._maybe_rotate_dims(x, rotate_right=True) log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._is_maybe_event_override: @@ -118,6 +129,16 @@ class ConditionalTransformedDistribution( distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + if self.bijector._is_injective: # pylint: disable=protected-access + return self._finish_prob_for_one_fiber(y, x, ildj, distribution_kwargs) + + prob_on_fibers = [ + self._finish_prob_for_one_fiber(y, x_i, ildj_i, distribution_kwargs) + for x_i, ildj_i in zip(x, ildj)] + return sum(prob_on_fibers) + + def _finish_prob_for_one_fiber(self, y, x, ildj, distribution_kwargs): + """Finish computation of prob on one element of the inverse image.""" x = self._maybe_rotate_dims(x, rotate_right=True) prob = self.distribution.prob(x, **distribution_kwargs) if self._is_maybe_event_override: @@ -129,6 +150,9 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("log_cdf is not implemented when overriding " "event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("log_cdf is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) @@ -139,6 +163,9 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("cdf is not implemented when overriding " "event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("cdf is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) @@ -150,6 +177,9 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("log_survival_function is not implemented when " "overriding event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("log_survival_function is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) @@ -161,6 +191,9 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("survival_function is not implemented when " "overriding event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("survival_function is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index b5e3decd6c966c67a6d3d1b341f78272acd3fec5..869b5698e57d199755ce1686a74a1eafe3b73e7d 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -160,7 +160,7 @@ def make_tril_scale( scale_tril = array_ops.matrix_set_diag(scale_tril, tril_diag) - return linalg.LinearOperatorTriL( + return linalg.LinearOperatorLowerTriangular( tril=_maybe_attach_assertion(scale_tril), is_non_singular=True, is_self_adjoint=False, @@ -378,6 +378,30 @@ def prefer_static_broadcast_shape( return array_ops.broadcast_dynamic_shape(shape1_, shape2_) +def get_broadcast_shape(*tensors): + """Get broadcast shape as a Python list of integers (preferred) or `Tensor`. + + Args: + *tensors: One or more `Tensor` objects (already converted!). + + Returns: + broadcast shape: Python list (if shapes determined statically), otherwise + an `int32` `Tensor`. + """ + # Try static. + s_shape = tensors[0].shape + for t in tensors[1:]: + s_shape = array_ops.broadcast_static_shape(s_shape, t.shape) + if s_shape.is_fully_defined(): + return s_shape.as_list() + + # Fallback on dynamic. + d_shape = array_ops.shape(tensors[0]) + for t in tensors[1:]: + d_shape = array_ops.broadcast_dynamic_shape(d_shape, array_ops.shape(t)) + return d_shape + + def is_diagonal_scale(scale): """Returns `True` if `scale` is a `LinearOperator` that is known to be diag. diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index e92bcf8c1fda3e550616abada2879e355866c055..5558ef0f255db684b229d129666634e50c625887 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -260,6 +260,14 @@ class MixtureSameFamily(distribution.Distribution): probs * self.components_distribution.mean(), axis=-1 - self._event_ndims) # [B, E] + def _log_cdf(self, x): + x = self._pad_sample_dims(x) + log_cdf_x = self.components_distribution.log_cdf(x) # [S, B, k] + log_mix_prob = nn_ops.log_softmax( + self.mixture_distribution.logits, dim=-1) # [B, k] + return math_ops.reduce_logsumexp( + log_cdf_x + log_mix_prob, axis=-1) # [S, B] + def _variance(self): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index ee3e02e0203a3338b7e6a40b7e3ff30c0a0940f0..040bc230722194316b8a74627344e315a2578281 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -237,7 +237,7 @@ class MultivariateNormalDiagPlusLowRank( scale_perturb_diag, name="scale_perturb_diag") if has_low_rank: - scale = linalg.LinearOperatorUDVHUpdate( + scale = linalg.LinearOperatorLowRankUpdate( scale, u=scale_perturb_factor, diag_update=scale_perturb_diag, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 221eed547bacd59d3c0d065f386fe45970f9bae9..f9952b2069d6dfd2593e6bd71ede0badf44cdf98 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -174,8 +174,8 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = control_flow_ops.with_dependencies( [assert_symmetric], covariance_matrix) # No need to validate that covariance_matrix is non-singular. - # LinearOperatorTriL has an assert_non_singular method that is called - # by the Bijector. + # LinearOperatorLowerTriangular has an assert_non_singular method that + # is called by the Bijector. # However, cholesky() ignores the upper triangular part, so we do need # to separately assert symmetric. scale_tril = linalg_ops.cholesky(covariance_matrix) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 50c7ba418be5b66127a3fde9f02a39b8f52ff841..251c2dbdfa59135be92afca30de88f23b2a40b4d 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator from tensorflow.python.framework import ops @@ -28,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = [ @@ -92,7 +92,7 @@ class MultivariateNormalLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -106,7 +106,7 @@ class MultivariateNormalLinearOperator( mvn = ds.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorTriL(scale)) + scale=la.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -243,8 +243,8 @@ class MultivariateNormalLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense())) else: @@ -254,8 +254,8 @@ class MultivariateNormalLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return math_ops.sqrt(array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense()))) else: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 48c4dddc8133d408e1beb7a8aef2abd676895fe3..e3d68f6b4c0d8837e42c8f0a20d8c711bb21c9d6 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -188,9 +188,9 @@ class MultivariateNormalTriL( assert_proper_shapes=validate_args) else: # No need to validate that scale_tril is non-singular. - # LinearOperatorTriL has an assert_non_singular method that is called - # by the Bijector. - scale = linalg.LinearOperatorTriL( + # LinearOperatorLowerTriangular has an assert_non_singular + # method that is called by the Bijector. + scale = linalg.LinearOperatorLowerTriangular( scale_tril, is_non_singular=True, is_self_adjoint=False, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py new file mode 100644 index 0000000000000000000000000000000000000000..b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -0,0 +1,217 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SinhArcsinh transformation of a distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import normal +from tensorflow.python.ops.distributions import transformed_distribution + +__all__ = [ + "SinhArcsinh", +] + + +class SinhArcsinh(transformed_distribution.TransformedDistribution): + """The SinhArcsinh transformation of a distribution on `(-inf, inf)`. + + This distribution models a random variable, making use of + a `SinhArcsinh` transformation (which has adjustable tailweight and skew), + a rescaling, and a shift. + + The `SinhArcsinh` transformation of the Normal is described in great depth in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865). + Here we use a slightly different parameterization, in terms of `tailweight` + and `skewness`. Additionally we allow for distributions other than Normal, + and control over `scale` as well as a "shift" parameter `loc`. + + #### Mathematical Details + + Given random variable `Z`, we define the SinhArcsinh + transformation of `Z`, `Y`, parameterized by + `(loc, scale, skewness, tailweight)`, via the relation: + + ``` + Y := loc + scale * F(Z) * (2 / F_0(2)) + F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + ``` + + This distribution is similar to the location-scale transformation + `L(Z) := loc + scale * Z` in the following ways: + + * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then + `Y = L(Z)` exactly. + * `loc` is used in both to shift the result by a constant factor. + * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0` + `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`. + Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond + `loc + 2 * scale` are the same. + + This distribution is different than `loc + scale * Z` due to the + reshaping done by `F`: + + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, the mode of `F(Z)` is "tilted" to the right. + * positive skew means positive values of `F(Z)` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|F(Z)|` become more likely. + * `tailweight < 1` leads to a distribution that is "flat" around `Y = loc`, + and a very steep drop-off in the tails. + * `tailweight > 1` leads to a distribution more peaked at the mode with + heavier tails. + + To see the argument about the tails, note that for `|Z| >> 1` and + `|Z| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`. + + To see the argument regarding multiplying `scale` by `2 / F_0(2)`, + + ``` + P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2] + = P[F(Z) <= F_0(2)] + = P[Z <= 2] (if F = F_0). + ``` + """ + + def __init__(self, + loc, + scale, + skewness=None, + tailweight=None, + distribution=None, + validate_args=False, + allow_nan_stats=True, + name="SinhArcsinh"): + """Construct SinhArcsinh distribution on `(-inf, inf)`. + + Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape + (indexing batch dimensions). They must all have the same `dtype`. + + Args: + loc: Floating-point `Tensor`. + scale: `Tensor` of same `dtype` as `loc`. + skewness: Skewness parameter. Default is `0.0` (no skew). + tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) + distribution: `tf.Distribution`-like instance. Distribution that is + transformed to produce this distribution. + Default is `ds.Normal(0., 1.)`. + Must be a scalar-batch, scalar-event distribution. Typically + `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is + a function of non-trainable parameters. WARNING: If you backprop through + a `SinhArcsinh` sample and `distribution` is not + `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then + the gradient will be incorrect! + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + parameters = locals() + + with ops.name_scope(name, values=[loc, scale, skewness, tailweight]): + loc = ops.convert_to_tensor(loc, name="loc") + dtype = loc.dtype + scale = ops.convert_to_tensor(scale, name="scale", dtype=dtype) + tailweight = 1. if tailweight is None else tailweight + has_default_skewness = skewness is None + skewness = 0. if skewness is None else skewness + tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=dtype) + skewness = ops.convert_to_tensor(skewness, name="skewness", dtype=dtype) + + batch_shape = distribution_util.get_broadcast_shape( + loc, scale, tailweight, skewness) + + # Recall, with Z a random variable, + # Y := loc + C * F(Z), + # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + # C := 2 * scale / F_0(2) + if distribution is None: + distribution = normal.Normal( + loc=array_ops.zeros([], dtype=dtype), + scale=array_ops.ones([], dtype=dtype), + allow_nan_stats=allow_nan_stats) + else: + asserts = distribution_util.maybe_check_scalar_distribution( + distribution, dtype, validate_args) + if asserts: + loc = control_flow_ops.with_dependencies(asserts, loc) + + # Make the SAS bijector, 'F'. + f = bijectors.SinhArcsinh( + skewness=skewness, tailweight=tailweight, event_ndims=0) + if has_default_skewness: + f_noskew = f + else: + f_noskew = bijectors.SinhArcsinh( + skewness=skewness.dtype.as_numpy_dtype(0.), + tailweight=tailweight, event_ndims=0) + + # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2)) + c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype)) + affine = bijectors.Affine( + shift=loc, + scale_identity_multiplier=c, + validate_args=validate_args, + event_ndims=0) + + bijector = bijectors.Chain([affine, f]) + + super(SinhArcsinh, self).__init__( + distribution=distribution, + bijector=bijector, + batch_shape=batch_shape, + validate_args=validate_args, + name=name) + self._parameters = parameters + self._loc = loc + self._scale = scale + self._tailweight = tailweight + self._skewness = skewness + + @property + def loc(self): + """The `loc` in `Y := loc + scale @ F(Z) * (2 / F(2)).""" + return self._loc + + @property + def scale(self): + """The `LinearOperator` `scale` in `Y := loc + scale @ F(Z) * (2 / F(2)).""" + return self._scale + + @property + def tailweight(self): + """Controls the tail decay. `tailweight > 1` means faster than Normal.""" + return self._tailweight + + @property + def skewness(self): + """Controls the skewness. `Skewness > 0` means right skew.""" + return self._skewness diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 6d297ea1f11398bb6abcb73aef2fce15bf7b429f..438d628da481d387f74b40fab3de62349061668c 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -23,10 +23,6 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_full_matrix as linop_full_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_tril as linop_tril_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -37,6 +33,10 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib +from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib +from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib +from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib static_value = distribution_util.static_value @@ -185,7 +185,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -772,8 +772,8 @@ def linop_scale(w, op): is_non_singular=op.is_non_singular, is_self_adjoint=op.is_self_adjoint, is_positive_definite=op.is_positive_definite) - if isinstance(op, linop_tril_lib.LinearOperatorTriL): - return linop_tril_lib.LinearOperatorTriL( + if isinstance(op, linop_tril_lib.LinearOperatorLowerTriangular): + return linop_tril_lib.LinearOperatorLowerTriangular( tril=w[..., array_ops.newaxis, array_ops.newaxis] * op.to_dense(), is_non_singular=op.is_non_singular, is_self_adjoint=op.is_self_adjoint, diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index c88572e17fa43ac11778bdddc02484d284b6eb36..356d78b67a8107750f68f7f84d73d1231f5b2b03 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -90,7 +90,7 @@ class VectorExponentialDiag( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 7123165417ea010fa9da5263e429734d34df3dbd..b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import ops @@ -26,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = ["VectorExponentialLinearOperator"] @@ -108,7 +108,7 @@ class VectorExponentialLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. @@ -247,7 +247,7 @@ class VectorExponentialLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and self.scale.is_self_adjoint): return array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense())) @@ -258,7 +258,7 @@ class VectorExponentialLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and self.scale.is_self_adjoint): return math_ops.sqrt( array_ops.matrix_diag_part(self.scale.matmul(self.scale.to_dense()))) diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index fdee57695e4e598929396ee4c9fe9f8014ea0f8b..c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import ops @@ -28,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = [ @@ -110,7 +110,7 @@ class VectorLaplaceLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] @@ -126,7 +126,7 @@ class VectorLaplaceLinearOperator( # Divide scale by sqrt(2) so that the final covariance will be what we want. vla = ds.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorTriL(scale / tf.sqrt(2))) + scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -271,8 +271,8 @@ class VectorLaplaceLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return 2. * math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return array_ops.matrix_diag_part( 2. * self.scale.matmul(self.scale.to_dense())) else: @@ -282,8 +282,8 @@ class VectorLaplaceLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return np.sqrt(2) * math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return np.sqrt(2) * math_ops.sqrt(array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense()))) else: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 5b3208ca79fd5bc5ac2a1edae8c3959d89b94d79..544a8710709a0afb56c6ae6f36d35de892e8e420 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""SinhArcsinh transformation of a distribution.""" +"""Multi-dimensional (Vector) SinhArcsinh transformation of a distribution.""" from __future__ import absolute_import from __future__ import division @@ -52,8 +52,9 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): matrix multiplication): ``` - Y := loc + scale @ F(Z) * (2 / F(2)) + Y := loc + scale @ F(Z) * (2 / F_0(2)) F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) ``` This distribution is similar to the location-scale transformation @@ -62,12 +63,12 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then `Y = L(Z)` exactly. * `loc` is used in both to shift the result by a constant factor. - * Our definition of `C` ensures that + * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0` `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`. Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond `loc + 2 * scale` are the same. - This distribution is different than `loc + diag(scale) @ Z` due to the + This distribution is different than `loc + scale @ Z` due to the reshaping done by `F`: * Positive (negative) `skewness` leads to positive (negative) skew. @@ -85,12 +86,12 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): `|Z| >> (|skewness| * tailweight)**tailweight`, we have `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`. - To see the argument about `C` and quantiles, note that + To see the argument regarding multiplying `scale` by `2 / F_0(2)`, ``` - P[(Y - loc) / scale <= 2] = P[F(Z) <= 2 * scale / C] - = P[Z <= F^{-1}(2 * scale / C)] - = P[Z <= 2]. + P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2] + = P[F(Z) <= F_0(2)] + = P[Z <= 2] (if F = F_0). ``` """ @@ -171,12 +172,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): ]): loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc tailweight = 1. if tailweight is None else tailweight + has_default_skewness = skewness is None skewness = 0. if skewness is None else skewness - # Recall, with Z ~ Normal(0, 1), + # Recall, with Z a random variable, # Y := loc + C * F(Z), # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) - # C := 2 * scale / F(2) + # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + # C := 2 * scale / F_0(2) # Construct shapes and 'scale' out of the scale_* and loc kwargs. # scale_linop is only an intermediary to: @@ -213,9 +216,16 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): tailweight, dtype=dtype, name="tailweight") f = bijectors.SinhArcsinh( skewness=skewness, tailweight=tailweight, event_ndims=1) + if has_default_skewness: + f_noskew = f + else: + f_noskew = bijectors.SinhArcsinh( + skewness=skewness.dtype.as_numpy_dtype(0.), + tailweight=tailweight, event_ndims=0) # Make the Affine bijector, Z --> loc + C * Z. - c = 2 * scale_diag_part / f.forward(ops.convert_to_tensor(2, dtype=dtype)) + c = 2 * scale_diag_part / f_noskew.forward( + ops.convert_to_tensor(2, dtype=dtype)) affine = bijectors.Affine( shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 9d30ce67197ebdeefc69d9b9979fdad4797bb183..e4ac65012b9c7e3ed5ada3ed75020f3905740156 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -251,8 +251,8 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbM) where M is the complexity of the operator solving a # vector system. E.g., for LinearOperatorDiag, each matmul is O(k**2), so - # this complexity is O(nbk**2). For LinearOperatorTriL, each matmul is - # O(k^3) so this step has complexity O(nbk^3). + # this complexity is O(nbk**2). For LinearOperatorLowerTriangular, + # each matmul is O(k^3) so this step has complexity O(nbk^3). x = self.scale_operator.matmul(x) # Undo make batch-op ready. @@ -307,8 +307,8 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbM*k) where M is the complexity of the operator solving # a vector system. E.g., for LinearOperatorDiag, each solve is O(k), so - # this complexity is O(nbk**2). For LinearOperatorTriL, each solve is - # O(k**2) so this step has complexity O(nbk^3). + # this complexity is O(nbk**2). For LinearOperatorLowerTriangular, + # each solve is O(k**2) so this step has complexity O(nbk^3). scale_sqrt_inv_x_sqrt = self.scale_operator.solve( scale_sqrt_inv_x_sqrt) @@ -544,7 +544,7 @@ class WishartCholesky(_WishartLinearOperator): super(WishartCholesky, self).__init__( df=df, - scale_operator=linalg.LinearOperatorTriL( + scale_operator=linalg.LinearOperatorLowerTriangular( tril=scale, is_non_singular=True, is_positive_definite=True, @@ -655,7 +655,7 @@ class WishartFull(_WishartLinearOperator): ] if validate_args else [], chol) super(WishartFull, self).__init__( df=df, - scale_operator=linalg.LinearOperatorTriL( + scale_operator=linalg.LinearOperatorLowerTriangular( tril=chol, is_non_singular=True, is_positive_definite=True, diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index dd305a78dcf716b62a444e5f9dcc01c708c522be..0c61630aa8f79e3efd25584478547abd99f30285 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -9,8 +9,12 @@ py_library( name = "tfe", srcs = ["tfe.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":datasets", + ":evaluator", + ":metrics", + ":network", ":saver", ":summary_writer", "//tensorflow/python:framework_ops", @@ -81,6 +85,7 @@ cuda_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", + "//tensorflow/python/eager:graph_callable", "//tensorflow/python:platform_test", "//tensorflow/python:variables", ], @@ -115,6 +120,85 @@ cuda_py_test( ], ) +py_library( + name = "metrics", + srcs = [ + "metrics.py", + "metrics_impl.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "metrics_test", + srcs = ["metrics_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":metrics", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "evaluator", + srcs = [ + "evaluator.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":datasets", + ":metrics", + ], +) + +py_test( + name = "evaluator_test", + srcs = ["evaluator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":evaluator", + ":metrics", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "network", + srcs = ["network.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "network_test", + srcs = ["network_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":network", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python/eager:test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 9973f4eee227b6c039aab58cfa40efe629fb2312..fb9fabd6c1b48b9e3a4572d4eb8f6546f2f17c43 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -23,6 +23,7 @@ import threading from tensorflow.python.data.util import nest from tensorflow.python.eager import context from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -62,20 +63,22 @@ class Iterator(object): raise RuntimeError( "{} objects only make sense when eager execution is enabled".format( type(self))) - ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access - self._output_types = dataset.output_types - self._flat_output_types = nest.flatten(dataset.output_types) - self._flat_output_shapes = nest.flatten(dataset.output_shapes) - self._resource = gen_dataset_ops.iterator( - container="", - shared_name=_iterator_shared_name(), - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) + with ops.device("/device:CPU:0"): + ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access + self._output_types = dataset.output_types + self._flat_output_types = nest.flatten(dataset.output_types) + self._flat_output_shapes = nest.flatten(dataset.output_shapes) + self._resource = gen_dataset_ops.iterator( + container="", + shared_name=_iterator_shared_name(), + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + gen_dataset_ops.make_iterator(ds_variant, self._resource) def __del__(self): if self._resource is not None: - resource_variable_ops.destroy_resource_op(self._resource) + with ops.device("/device:CPU:0"): + resource_variable_ops.destroy_resource_op(self._resource) self._resource = None def __iter__(self): @@ -87,10 +90,14 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" try: - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - return nest.pack_sequence_as(self._output_types, ret) + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + with ops.device("/device:CPU:0"): + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + return nest.pack_sequence_as(self._output_types, ret) except errors.OutOfRangeError: raise StopIteration diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a2da6b28c6bdbfa0f6a4ed5d303aa4a36b81fc19..076c92e73f7c2a1ebc6dbeac940a8307adc16414 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -19,7 +19,9 @@ from __future__ import print_function from tensorflow.contrib.data import Dataset from tensorflow.contrib.eager.python import datasets from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops class IteratorTest(test.TestCase): @@ -69,6 +71,16 @@ class IteratorTest(test.TestCase): got2 = [x.numpy() for x in datasets.Iterator(ds)] self.assertAllEqual(got1, got2) + def testPyFunc(self): + + def my_map(inp): + return [[x + 1 for x in inp]] + + ds = Dataset.range(4).map( + lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64)) + got = [x.numpy() for x in datasets.Iterator(ds)] + self.assertAllEqual([[1], [2], [3], [4]], got) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..d757e976eeafa36ec5e870cfde0c620a204d7440 --- /dev/null +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -0,0 +1,217 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class Evaluator holds Metrics for the duration of an evaluation run.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.eager.python import datasets +from tensorflow.contrib.eager.python import metrics + + +class Evaluator(object): + """This holds and updates Metrics for the duration of a single eval run. + + Usage: + evaluator = my_model.evaluator() # or MyEvaluator(my_model) + for example_batch in ...: + evaluator(example_batch) + results = evaluator.all_metric_results(optional_summary_writer) + + Or, if you are getting your examples from a tf.data.Dataset, you can use + the evaluate_on_dataset() method. + + Implementers of Evaluators should + (a) Call `add_metric()` and/or `add_evaluator()` in __init__(). + (b) Override the `call()` method. It will be passed the output of the + model's `eval_data()` method, and should call its contained metrics + (treating them as callables) and any child Evaluators (using their + call() method to avoid calling eval_data() again). + + Args: + model: A `Model` object with an `eval_data()` method. + """ + + def __init__(self, model): + self._model = model + self._metrics = {} + self._evaluators = {} + + # ---- API for users ---- + def __call__(self, *args, **kwargs): + """Update metrics with a minibatch of input examples.""" + return self.call(self._model.eval_data(*args, **kwargs)) + + def all_metric_results(self): # TODO(josh11b): Add optional summary_writer. + """Returns dict mapping metric name -> value.""" + results = {} + for name, metric in six.iteritems(self._metrics): + results[name] = metric.result() + for prefix, evaluator in six.iteritems(self._evaluators): + for name, metric in six.iteritems(evaluator._metrics): # pylint: disable=protected-access + results[prefix + "/" + name] = metric.result() + return results + + def evaluate_on_dataset(self, dataset, *args, **kwargs): + """Convenience method for performing an eval on a Dataset.""" + for example in datasets.Iterator(dataset): + self.__call__(example, *args, **kwargs) + # TODO(josh11b): Add optional summary_writer. + return self.all_metric_results() + + # ---- To be implemented by descendants --- + def call(self, eval_data): + """Update metrics using the output of self.model.""" + raise NotImplementedError("Evaluators must define a call member function.") + + # ---- For use by descendants --- + @property + def model(self): + return self._model + + def add_metric(self, metric): + """Add a Metric to be tracked. + + Rule: metrics can only be in one `Evaluator`. + + Args: + metric: A `Metric` object. + + Returns: + The `metric` passed into this function. + + Raises: + RuntimeError: If called before __init__. + TypeError: If `metric` is not of the correct type. + ValueError: If there is a name collision between Metrics. + """ + if not hasattr(self, "_metrics"): + raise RuntimeError( + "Need to call Evaluator.__init__ before adding metrics") + if not isinstance(metric, metrics.Metric): + raise TypeError( + "Evaluator.add_metric() passed type %s, not a tfe.metrics.Metric" % + (type(metric),)) + if metric.name in self._metrics: + if metric is self._metrics[metric.name]: + return metric + raise ValueError( + "Attempt to add two Metrics with the name '%s' to the same Evaluator " + "'%s'" % (metric.name, self.name)) + self._metrics[metric.name] = metric + return metric + + def add_evaluator(self, prefix, evaluator): + """Add a contained `Evaluator`. + + This is for delegating to another `Evaluator`, e.g. for when you have a + model with multiple heads. Users should manually invoke the child + `Evaluator`'s `call` method from their `call` method. + + Args: + prefix: A string. Metrics from `evaluator` are exported with this + prefix and a '/'. + evaluator: An `Evaluator` object. + + Returns: + The value of `evaluator` passed into this function. + + Raises: + RuntimeError: If called before __init__. + TypeError: If `evaluator` is not of the correct type. + ValueError: If an `Evaluator` has already been added with that `prefix`. + """ + if not hasattr(self, "_evaluators"): + raise RuntimeError( + "Need to call Evaluator.__init__ before adding evaluators") + if not isinstance(evaluator, Evaluator): + raise TypeError( + "Evaluator.add_evaluator() passed type %s, not a tfe.Evaluator." % + (type(evaluator),)) + if prefix in self._evaluators: + if evaluator is self._evaluators[prefix]: + return evaluator + raise RuntimeError( + "Attempt to add two Evaluators with the same prefix '%s'." % prefix) + self._evaluators[prefix] = evaluator + return evaluator + + @property + def metric_variables(self): + v = [] + for metric in six.itervalues(self._metrics): + v += metric.variables + for evaluator in six.itervalues(self._evaluators): + v += evaluator.metric_variables + return v + + @property + def metrics(self): + m = [] + for metric in six.itervalues(self._metrics): + m.append(metric) + for evaluator in six.itervalues(self._evaluators): + m += evaluator.metrics + return m + + +class SparseSoftmaxEvaluator(Evaluator): + """Evaluator for a sparse softmax model. + + Computes a standard set of metrics for single-label, multi-class + models. + + Args: + model: A `SparseSoftmaxModel` object or a `Model` whose `eval_data()` + method produces a `dict` containing values for the loss, true + label, predicted class, and optional weights. + loss_key: Optional key for looking up the value of the loss in the + `eval_data()` dict. Defaults to "loss". + label_key: Optional key for looking up the value of the label in the + `eval_data()` dict. Defaults to "label". + predicted_class_key: Optional key for looking up the value of the + predicted class in the `eval_data()` dict. Defaults to "predicted_class". + weights_key: Optional key for looking up the value of the weights + in the `eval_data()` dict. Defaults to "weights". Note that weights + are optional, and default to 1 if not present in `eval_data`. + """ + + def __init__(self, model, loss_key="loss", label_key="label", + predicted_class_key="predicted_class", weights_key="weights"): + super(SparseSoftmaxEvaluator, self).__init__(model) + # TODO(josh11b): Expand this to include everything from the standard + # SparseSoftmax Head. + self.avg_loss = self.add_metric(metrics.Mean("Avg_Loss")) + self.accuracy = self.add_metric(metrics.Accuracy()) + self.loss_key = loss_key + self.label_key = label_key + self.predicted_class_key = predicted_class_key + self.weights_key = weights_key + + def call(self, eval_data): + """Update metrics for `eval_data` dict (described above).""" + weights = eval_data.get(self.weights_key, None) + if weights is None: + self.avg_loss(eval_data[self.loss_key]) + self.accuracy(eval_data[self.label_key], + eval_data[self.predicted_class_key]) + else: + self.avg_loss(eval_data[self.loss_key], weights=weights) + self.accuracy(eval_data[self.label_key], + eval_data[self.predicted_class_key], + weights=weights) diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..099e10e2307b2e3c406ccf847fc8ee2bce9ce407 --- /dev/null +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -0,0 +1,124 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Evaluator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.eager.python import evaluator +from tensorflow.contrib.eager.python import metrics +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test + + +class IdentityModel(object): + + def eval_data(self, d): + return d + + +class PrefixLModel(object): + + def eval_data(self, d): + return {"l_" + key: d[key] for key in d} + + +class SimpleEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(SimpleEvaluator, self).__init__(model) + self.mean = self.add_metric(metrics.Mean("mean")) + + def call(self, eval_data): + self.mean(eval_data) + + +class DelegatingEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(DelegatingEvaluator, self).__init__(model) + self.sub = self.add_evaluator("inner", SimpleEvaluator(model)) + self.mean = self.add_metric(metrics.Mean("outer-mean")) + + def call(self, eval_data): + # Keys here come from PrefixLModel, which adds "l_". + self.mean(eval_data["l_outer"]) + self.sub.call(eval_data["l_inner"]) + + +# pylint: disable=not-callable +class EvaluatorTest(test.TestCase): + + def testSimple(self): + e = SimpleEvaluator(IdentityModel()) + e(3.0) + e([5.0, 7.0, 9.0]) + results = e.all_metric_results() + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"].numpy()) + + def testComposition(self): + e = DelegatingEvaluator(PrefixLModel()) + e({"inner": 2.0, "outer": 100.0}) + e({"inner": 4.0, "outer": 1000.0}) + results = e.all_metric_results() + self.assertEqual(set(["inner/mean", "outer-mean"]), set(results.keys())) + self.assertEqual(3.0, results["inner/mean"].numpy()) + self.assertEqual(550.0, results["outer-mean"].numpy()) + + def testMetricVariables(self): + e = DelegatingEvaluator(PrefixLModel()) + e({"inner": 2.0, "outer": 100.0}) + prefix_count = {} + for v in e.metric_variables: + p = v.name.split("/")[0] + prefix_count[p] = prefix_count.get(p, 0) + 1 + self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count) + + def testDataset(self): + e = SimpleEvaluator(IdentityModel()) + ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) + results = e.evaluate_on_dataset(ds) + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"].numpy()) + + def testModelProperty(self): + m = IdentityModel() + e = SimpleEvaluator(m) + self.assertIs(m, e.model) + + def testMetricsProperty(self): + e = DelegatingEvaluator(PrefixLModel()) + names = set([m.name for m in e.metrics]) + self.assertEqual(set(["outer-mean", "mean"]), names) + + +class SparseSoftmaxEvaluatorTest(test.TestCase): + + def testSimple(self): + e = evaluator.SparseSoftmaxEvaluator(IdentityModel()) + e({e.loss_key: 1.0, e.label_key: 5, e.predicted_class_key: 5}) + e({e.loss_key: [0.0, 3.0, 4.0], + e.label_key: [1, 2, 3], + e.predicted_class_key: [1, 1, 3]}) + results = e.all_metric_results() + self.assertEqual(set(["Avg_Loss", "Accuracy"]), set(results.keys())) + self.assertEqual(2.0, results["Avg_Loss"].numpy()) + self.assertEqual(0.75, results["Accuracy"].numpy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/metrics.py b/tensorflow/contrib/eager/python/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3100427376ddd480b50d967cf53e7831aaefb2 --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics.py @@ -0,0 +1,26 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Metrics namespace.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint:disable=wildcard-import +from tensorflow.contrib.eager.python.metrics_impl import * +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['Accuracy', 'Mean', 'Metric'] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..63a0f8d9a45dfb12fd1d61a1156b9acf20cf4c81 --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -0,0 +1,199 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Metrics classes for computing the output of an evaluation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope + + +class Metric(object): + """A metric holds state for aggregating statistics over an evaluation run. + + Users will use Evaluator.add_metric() to add Metric objects to their + evaluation, call them in each step, and then use + Evaluator.all_metric_results() at the end. + + Descendants will implement: + * call(): Should follow this pattern: + if not self.built: + self.var = self.add_variable(...) + self.add_update(self.var.assign_add(...)) + * aggregate(): Adds in the state from a list of metrics of the same type + as `self`. (Default of summing all the variables will be fine for most + descendants.) + * result(): Computes and returns a final value for the metric + from the variables in `self`. + """ + + def __init__(self, name=None): + self.built = False + self._vars = [] + self._updates = [] + self._name = name or self.__class__.__name__ + # TODO(josh11b): Need some way to make sure two Metrics in the same + # Network have distinct names. Maybe we can get a unique name from + # a name/variable scope? + # TODO(josh11b): self._in_graph_mode = context.in_graph_mode() + + # ---- API for users ---- + def __call__(self, *args, **kwargs): + # TODO(josh11b): If self._in_graph_mode is true, make self.call() into a + # graph callable here, so that variable updates happen without requiring + # a separate fetch. + # TODO(josh11b): Do we need a separate build() method to separate + # initialization from each update? If so, how do we get the arguments + # to it? We *could* just pass in *args and **kwargs... + if not self.built: + # TODO(ashankar): Set up container isolation so there is no chance + # distinct metrics objects accidentally share variables. + # TODO(josh11b): Replace things like spaces in self._name to create + # a valid scope name. + with variable_scope.variable_scope( + self._name, use_resource=True, reuse=False): + ret = self.call(*args, **kwargs) + self.built = True + else: + ret = self.call(*args, **kwargs) + return ret + + @property + def name(self): + return self._name + + @property + def variables(self): + return self._vars + + # ---- To be implemented by descendants --- + def call(self, *args, **kwargs): + """Accumulates statistics for the metric.""" + raise NotImplementedError("Metrics must define a call() member function") + + # We can support two different strategies of for doing data-parallel + # distributed metric computations: + # * Put metric variables on the first device and rely on small + # bandwidth needed to do updates. (Doesn't require any particular + # code in Metric implementations.) + # * Ask each type of metric to define an aggregation method to run + # at the end of eval to merge across devices. Note: this is good + # for the use case where they want to record the metric's state + # for each example and then later decide which examples they want + # to aggregate over. (Recommended -- not too much harder and adds + # flexibility over previous option.) + # I'm going with the second strategy since we can define a default + # implementation of aggregate() that will work for most descendants. + def aggregate(self, metrics): + """Adds in the state from a list of metrics. + + Default implementation sums all the metric variables. + + Args: + metrics: A list of metrics with the same type as `self`. + + Raises: + ValueError: If metrics contains invalid data. + """ + for m in metrics: + if type(self) != type(m): # pylint: disable=unidiomatic-typecheck + raise TypeError("All metrics must be the same type, '%s' != '%s'." % + (type(self), type(m))) + # pylint: disable=protected-access + for i in range(len(self._vars)): + if any(m._vars[i].name != self._vars[i].name for m in metrics): + raise ValueError("All metrics must have variables in the same order.") + self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics])) + # pylint: enable=protected-access + + def result(self): # TODO(josh11b): Add an optional summary_writer parameter. + """Computes and returns a final value for the metric.""" + raise NotImplementedError("Metrics must define a result() member function") + + # ---- For use by descendants --- + def add_variable(self, name, shape=None, dtype=None, initializer=None): + """***Only for use by descendants of Metric***.""" + if self.built: + raise RuntimeError("Can't call add_variable() after a Metric has been " + "built in the first call().") + v = variable_scope.get_variable(name, shape, dtype, initializer, + trainable=False, use_resource=True) + self._vars.append(v) + return v + + +class Mean(Metric): + """Computes the (weighted) mean of the given values.""" + # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? + # Or defaults to type of the input if it is tf.float32, else tf.float64? + + def call(self, values, weights=None): + """Accumulate statistics for computing the mean. + + For example, if values is [1, 3, 5, 7] then the mean is 4. + If the weights were specified as [1, 1, 0, 0] then the mean would be 2. + + Args: + values: Tensor with the per-example value. + weights: Optional weighting of each example. Defaults to 1. + """ + if not self.built: # False only in the first call(). + self.numer = self.add_variable(name="numer", shape=(), + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + self.denom = self.add_variable(name="denom", shape=(), + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + if weights is None: + self.denom.assign_add( + math_ops.cast(array_ops.size(values), dtypes.float64)) + values = math_ops.reduce_sum(values) + self.numer.assign_add(math_ops.cast(values, dtypes.float64)) + else: + weights = math_ops.cast(weights, dtypes.float64) + self.denom.assign_add(math_ops.reduce_sum(weights)) + values = math_ops.cast(values, dtypes.float64) * weights + self.numer.assign_add(math_ops.reduce_sum(values)) + + def result(self): + return self.numer / self.denom + + +class Accuracy(Mean): + """Calculates how often `predictions` matches `labels`.""" + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + For example, if labels is [1, 2, 3, 4] and predictions is [0, 2, 3, 4] + then the accuracy is 3/4 or .75. If the weights were specified as + [1, 1, 0, 0] then the accuracy would be 1/2 or .5. + + `labels` and `predictions` should have the same shape and type. + + Args: + labels: Tensor with the true labels for each example. One example + per element of the Tensor. + predictions: Tensor with the predicted label for each example. + weights: Optional weighting of each example. Defaults to 1. + """ + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..089bad5a0e3049543bdc09b571319262a734809f --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Metrics.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.eager.python import metrics +from tensorflow.python.eager import test + + +class MetricsTest(test.TestCase): + + def testMean(self): + m = metrics.Mean() + m([1, 10, 100]) + m(1000) + m([10000.0, 100000.0]) + self.assertEqual(111111.0/6, m.result().numpy()) + + def testWeightedMean(self): + m = metrics.Mean() + m([1, 100, 100000], weights=[1, 0.2, 0.3]) + m([500000, 5000, 500]) # weights of 1 each + self.assertNear(535521/4.5, m.result().numpy(), 0.001) + + def testAccuracy(self): + m = metrics.Accuracy() + m([0, 1, 2, 3], [0, 0, 0, 0]) # 1 correct + m([4], [4]) # 1 correct + m([5], [0]) # 0 correct + m([6], [6]) # 1 correct + m([7], [2]) # 0 correct + self.assertEqual(3.0/8, m.result().numpy()) + + def testWeightedAccuracy(self): + m = metrics.Accuracy() + # 1 correct, total weight of 2 + m([0, 1, 2, 3], [0, 0, 0, 0], weights=[1, 1, 0, 0]) + m([4], [4], weights=[0.5]) # 1 correct with a weight of 0.5 + m([5], [0], weights=[0.5]) # 0 correct, weight 0.5 + m([6], [6]) # 1 correct, weight 1 + m([7], [2]) # 0 correct, weight 1 + self.assertEqual(2.5/5, m.result().numpy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py new file mode 100644 index 0000000000000000000000000000000000000000..bebc595df07dbd3a0ecfe5c93749f13332805539 --- /dev/null +++ b/tensorflow/contrib/eager/python/network.py @@ -0,0 +1,199 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A Network is a composition of Layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import uuid + +import six + +from tensorflow.python.framework import ops +from tensorflow.python.layers import base +from tensorflow.python.ops import variable_scope + + +class Network(base.Layer): + """Represents the composition of a set of Layers. + + TODO(josh11b,ashankar): + - Should "trainable" be changeable on the Network object? + - Do we allow add_variable in Network? + - Layer.name and Layer.variables.names are not in sync today + d = tf.layers.Dense(1) + d(tf.constant([[1.]])) + print(d.name) + print(d.variables) + - Note that name provided to __init__ is only for error messages? + - Detect layers used in __call__ that weren't registered with add_layer. + - Convert inputs to __call__ to tensors. + - Prevent variables from being created after the first __call__? + (Think about restoring from a checkpoint). + - Save & restore + """ + + def __init__(self, name=None): + super(Network, self).__init__(name=name) + self._container = uuid.uuid4().hex + self._layers = collections.OrderedDict() + + def add_layer(self, layer): + """Add a Layer to this Network. + + `Network` requires that all `Layer`s used in `call()` be added so that the + `Network` can export a complete list of variables. + + Args: + layer: A `tf.layers.Layer` object. + + Returns: + The passed in `layer`. + + Raises: + RuntimeError: If __init__ has not been called. + TypeError: If layer is the wrong type. + ValueError: If a layer with the same name has already been added. + """ + if not hasattr(self, "_layers"): + raise RuntimeError("Need to call Network.__init__ before adding layers") + if not isinstance(layer, base.Layer): + raise TypeError( + "Network.add_layer() passed type %s, not a tf.layers.Layer" % + (type(layer),)) + if layer.name in self._layers: + if self._layers[layer.name] is layer: + return layer + raise ValueError( + "Attempt to add two Layers with the name '%s' to the same Network " + "'%s'" % (layer.name, self.name)) + self._layers[layer.name] = layer + return layer + + def get_layer(self, name=None, index=None): + """Get a contained `tf.layers.Layer` either by name or index. + + Args: + name: String matching one of the names of a contained `Layer`. + index: Integer in [0, number of layers). Layers are assigned an index + by the order they are added. + + Returns: + A `tf.layers.Layer` object. + + Raises: + ValueError: If neither or both of 'index' or 'name' is specified. + """ + if index is not None: + if name is not None: + raise ValueError("Exactly one of 'index' or 'name' must be provided") + if len(self._layers) <= index: + raise ValueError("Was asked to retrieve layer at index " + + str(index) + " but model only has " + str( + len(self._layers)) + " layers.") + return list(self._layers.values())[index] + if name is None: + raise ValueError("Exactly one of 'index' or 'name' must be provided") + return self._layers[index] + + # The following methods are for implementing the Layer interface. + + @property + def weights(self): + # TODO(josh11b): Should this return a set or perform de-duplication of + # variables in the case of shared layers/variables that appear in + # multiple places in the Network? + weights = [] + for layer in six.itervalues(self._layers): + weights += layer.weights + return weights + + @property + def trainable_weights(self): + weights = [] + for layer in six.itervalues(self._layers): + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in six.itervalues(self._layers): + weights += layer.non_trainable_weights + return weights + + @property + def trainable(self): + return True + + @trainable.setter + def trainable(self, value): + if not value: + # We believe it better to decide which layers & networks are trainable + # at the Trainer level than here. Otherwise you can run into trouble if a + # layer/network is shared between two models, but is trainable in one + # but not the other (like with adversarial networks). + raise AttributeError("cannot mark Network as not trainable") + + @property + def layers(self): + return self._layers.values() + + def add_variable(self, name, shape, dtype=None, initializer=None, + regularizer=None, trainable=True, constraint=None): + raise RuntimeError( + "add_variable not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + def __call__(self, inputs, *args, **kwargs): + # TODO(josh11b,ashankar,agarwal): Can we reduce the number of context + # managers here and/or move some of the work into the constructor + # for performance reasons? + with ops.container(self._container): + with variable_scope.variable_scope(variable_scope.get_variable_scope(), + use_resource=True): + return super(Network, self).__call__(inputs, *args, **kwargs) + + # TODO(josh11b): Support other Layer methods needed for graph mode, such as for + # losses and updates + + +class Sequential(Network): + """Represents a linear sequence of Layers. + + The output of each layer is provided as the input to the next. + The inputs passed to `__call__` are passed to the inputs of the first + Layer, and it returns the outputs of the last Layer. + + Args: + layers: An optional sequence of tf.layers.Layer objects. + name: An optional string name to use for this Network. + """ + + def __init__(self, layers=None, name=None): + super(Sequential, self).__init__(name=name) + if layers: + for l in layers: + self.add_layer(l) + + def call(self, inputs): + """Call each Layer in the order they were added.""" + # TODO(josh11b): Support "mode" and maybe other arguments + for l in self.layers: + inputs = l(inputs) + return inputs diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f0dcae85ee139405784a70c2d3704b0bbcf9e4dd --- /dev/null +++ b/tensorflow/contrib/eager/python/network_test.py @@ -0,0 +1,107 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.eager.python import network +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.layers import core + + +# pylint: disable=not-callable +class MyNetwork(network.Network): + + def __init__(self): + super(MyNetwork, self).__init__(name="abcd") + self.l1 = self.add_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.l1(x) + + +class NetworkTest(test.TestCase): + + def testTrainableAttribute(self): + net = network.Network() + self.assertTrue(net.trainable) + with self.assertRaises(AttributeError): + net.trainable = False + self.assertTrue(net.trainable) + + def testNetworkCall(self): + net = MyNetwork() + net(constant_op.constant([[2.0]])) # Force variables to be created. + self.assertEqual(1, len(net.trainable_variables)) + net.trainable_variables[0].assign([[17.0]]) + # TODO(josh11b): Support passing Python values to networks. + result = net(constant_op.constant([[2.0]])) + self.assertEqual(34.0, result.numpy()) + + def testNetworkAsAGraph(self): + self.skipTest("TODO(ashankar,josh11b): FIX THIS") + # Verify that we're using ResourceVariables + + def testNetworkVariablesDoNotInterfere(self): + self.skipTest("TODO: FIX THIS") + net1 = MyNetwork() + net2 = MyNetwork() + + one = constant_op.constant([[1.]]) + + print(type(net1(one))) + net2(one) + + net1.trainable_weights[0].assign(constant_op.constant([[1.]])) + net2.trainable_weights[0].assign(constant_op.constant([[2.]])) + + print("NET1") + print(net1.name) + print(net1.variables) + print(net1(one)) + + print("NET2") + print(net2.name) + print(net2.variables) + print(net2(one)) + + +class SequentialTest(test.TestCase): + + def testTwoLayers(self): + # Create a sequential network with one layer. + net = network.Sequential([core.Dense(1, use_bias=False)]) + + # Set that layer's weights so it multiplies by 3 + l1 = net.get_layer(index=0) + net(constant_op.constant([[2.0]])) # Create l1's variables + self.assertEqual(1, len(l1.trainable_variables)) + l1.trainable_variables[0].assign([[3.0]]) + self.assertEqual(21.0, net(constant_op.constant([[7.0]])).numpy()) + + # Add a second layer to the network. + l2 = core.Dense(1, use_bias=False) + net.add_layer(l2) + + # Set the second layer's weights so it multiplies by 11 + net(constant_op.constant([[2.0]])) # Create l2's variables + self.assertEqual(1, len(l2.trainable_variables)) + l2.trainable_variables[0].assign([[11.0]]) + self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 8edd4b816397cf4ba5f7d43b78f6e50ee6619da1..2bf11d3f208014650909927bd794916eaba8a336 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -19,7 +19,9 @@ from __future__ import print_function import contextlib +from tensorflow.python.eager import context from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as _saver @@ -41,21 +43,75 @@ def _init_from_checkpoint(self, *args, **kwargs): # pylint: enable=protected-access +@contextlib.contextmanager +def restore_variables_on_create(save_path): + """ContextManager that restores variables on creation. + + When save_path is None (e.g. No checkpoint), does nothing. + Otherwise, it preloads all values from checkpoint. When the + corresponding variable is first created, it assigns the checkpoint + value to the variable. + + ```python + with restore_variables_on_create( + tf.train.latest_checkpoint(checkpoint_dir)): + ``` + + Args: + save_path: The checkpoint file prefix. + + Yields: + Nothing. + + Raises: + NotFoundError: If the variable is not found in checkpoint. + ValueError: If not used in eager mode. + """ + if context.in_graph_mode(): + raise ValueError( + "Currently, restore_variables_on_create can only be used with " + "eager execution enabled.") + if save_path: + ckpt_var_cache = dict() + reader = checkpoint_utils.load_checkpoint(save_path) + for k, _ in checkpoint_utils.list_variables(save_path): + ckpt_var_cache[k] = reader.get_tensor(k) + + old_init = getattr( + resource_variable_ops.ResourceVariable, "_init_from_args", None) + assert old_init, "ResourceVariable misses _init_from_args method." + setattr(resource_variable_ops.ResourceVariable, "_init_from_args", + _init_from_checkpoint) + setattr(resource_variable_ops.ResourceVariable, "old_init", old_init) + setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", + ckpt_var_cache) + try: + yield + except Exception as e: + raise e + finally: + if save_path: + setattr(resource_variable_ops.ResourceVariable, "_init_from_args", + old_init) + setattr(resource_variable_ops.ResourceVariable, "old_init", None) + setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None) + + class Saver(object): """A simple tf.train.Saver adapter for eager mode. save and restore API are similar to the tf.train.Saver, except that session is not needed. - restore_on_create is eager mode's way to reload checkpoint value during - the execution. (unlike graph mode's reload before run). - Args: - var_list: See tf.train.Saver. Works the same for save/restore. Ignored - by restore_on_create. + var_list: A list of variables. """ - def __init__(self, var_list=None): + def __init__(self, var_list): + if context.in_graph_mode(): + raise ValueError("Currently, tfe.Saver can only be used when eager " + "execution is enabled. Use tf.train.Saver when " + "building graphs.") self._saver = _saver.Saver(var_list=var_list) def save(self, save_path, global_step=None): @@ -68,8 +124,9 @@ class Saver(object): Returns: See save method in tf.train.Saver. """ - return self._saver.save(None, save_path, write_meta_graph=False, - global_step=global_step) + with ops.device("/device:CPU:0"): + return self._saver.save(None, save_path, write_meta_graph=False, + global_step=global_step) def restore(self, save_path): """Restores previously saved variables. @@ -77,47 +134,6 @@ class Saver(object): Args: save_path: See restore method in tf.train.Saver. """ - self._saver.restore(None, save_path) - - @contextlib.contextmanager - def maybe_restore_on_create(self, save_path): - """ContextManager that restores variables on creation. - - When save_path is None (e.g. No checkpoint), does nothing. - Otherwise, it preloads all values from checkpoint. When the - corresponding variable is first created, it assigns the checkpoint - value to the variable. - - Args: - save_path: Same as save_path of retore. If None, do not restore. - - Yields: - Nothing. + with ops.device("/device:CPU:0"): + self._saver.restore(None, save_path) - Raises: - NotFoundError: If the variable is not found in checkpoint. - """ - if save_path: - ckpt_var_cache = dict() - reader = checkpoint_utils.load_checkpoint(save_path) - for k, _ in checkpoint_utils.list_variables(save_path): - ckpt_var_cache[k] = reader.get_tensor(k) - - old_init = getattr( - resource_variable_ops.ResourceVariable, "_init_from_args", None) - assert old_init, "ResourceVariable misses _init_from_args method." - setattr(resource_variable_ops.ResourceVariable, "_init_from_args", - _init_from_checkpoint) - setattr(resource_variable_ops.ResourceVariable, "old_init", old_init) - setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", - ckpt_var_cache) - try: - yield - except Exception as e: - raise e - finally: - if save_path: - setattr(resource_variable_ops.ResourceVariable, "_init_from_args", - old_init) - setattr(resource_variable_ops.ResourceVariable, "old_init", None) - setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None) diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 9c8294e3bacc2c6fe2689d81cdf6efa7f8ddbc4b..29af2b531f4dee7f46c1538ff23409ece5785ceb 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -21,17 +21,24 @@ import os from tensorflow.contrib.eager.python import saver as _saver from tensorflow.python.eager import context +from tensorflow.python.eager import graph_callable +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test class SaverTest(test.TestCase): + def _dev(self): + return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0' + def testBasics(self): - with context.eager_mode(): + with context.eager_mode(), ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') def model(): return array_ops.constant(2.0) * v1 @@ -48,7 +55,7 @@ class SaverTest(test.TestCase): self.assertEqual(v1.read_value().numpy(), 1.0) def testRestoreOnCreate(self): - with context.eager_mode(): + with context.eager_mode(), ops.device(self._dev()): def model(init_val): v1 = resource_variable_ops.ResourceVariable(init_val, name='v1') return array_ops.constant(1.0) * v1, v1 @@ -60,7 +67,7 @@ class SaverTest(test.TestCase): with ops.Graph().as_default(): saver = _saver.Saver([v1]) - with saver.maybe_restore_on_create(ckpt_prefix): + with _saver.restore_variables_on_create(ckpt_prefix): # Value is from checkpoint, but not from argument. ret, _ = model(2.0) self.assertEqual(ret.numpy(), 1.0) @@ -69,7 +76,7 @@ class SaverTest(test.TestCase): self.assertEqual(v1_2.read_value().numpy(), 3.0) def testRestoreNotFound(self): - with context.eager_mode(): + with context.eager_mode(), ops.device(self._dev()): def model(v): return array_ops.constant(1.0) * v @@ -81,9 +88,56 @@ class SaverTest(test.TestCase): with self.assertRaisesRegexp(errors.NotFoundError, 'v2 not found in checkpoint'): - with saver.maybe_restore_on_create(ckpt_prefix): + with _saver.restore_variables_on_create(ckpt_prefix): _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2')) + def testSaveRestoreGraphCallable(self): + with context.eager_mode(), ops.device(self._dev()): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + # Default 2 + 0 = 2 + self.assertEqual( + 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # Save the variable value 0. + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + _saver.Saver(model.variables).save(ckpt_prefix) + + # update variable to 1, so that 2 + 1 = 3 + model.variables[0].assign(1.) + self.assertEqual( + 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # load the variable value 0, so that 2 + 0 = 2 + _saver.Saver(model.variables).restore(ckpt_prefix) + self.assertEqual( + 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # update checkpoint variable to 1 and memory value to 2. + model.variables[0].assign(1.) + _saver.Saver(model.variables).save(ckpt_prefix) + model.variables[0].assign(2.) + self.assertEqual( + 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # reset the graph and reload on create, so that 1 + 2 = 3 + with ops.Graph().as_default(): + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index f459e524bc28c0835371bf9ea01fb246fb9a6c62..1acb1ba1b8c2aa2af0f7f24bd37b5afea09fe74f 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -45,7 +45,13 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@Iterator @@Saver @@SummaryWriter +@@restore_variables_on_create @@Variable + +@@in_eager_mode +@@in_graph_mode + +@@run_test_in_graph_and_eager_modes """ from __future__ import absolute_import @@ -56,24 +62,28 @@ from __future__ import print_function # pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import # from tensorflow.contrib.eager.python.datasets import Iterator +from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver from tensorflow.contrib.eager.python.summary_writer import SummaryWriter -from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.eager import backprop -from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager import function from tensorflow.python.eager.context import enable_eager_execution +from tensorflow.python.eager.context import in_eager_mode +from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus from tensorflow.python.eager.context import run from tensorflow.python.eager.core import enable_tracing +from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable +from tensorflow.python.util.all_util import remove_undocumented defun = function.defun implicit_gradients = backprop.implicit_grad diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index ac2f388a856d90baf4567a1a9b7dbc55a181e5c5..3d57a98a2ee068281b0934484994e113989e75ce 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -67,8 +67,7 @@ class TFETest(test_util.TensorFlowTestCase): return y, grad_fn - # TODO(ashankar): This [0] should ideally not be needed. - grad = tfe.gradients_function(f, [0]) + grad = tfe.gradients_function(f) self.assertEquals([12], [x.numpy() for x in grad(3)]) def testGPU(self): diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 596f68844b3628d7101fe16e095db7b5160d5baf..4dd9f19ec3123112ac2dd3a6f2db0da90492a234 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -50,7 +50,10 @@ py_test( size = "small", srcs = ["python/estimator/dnn_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "notsan", + ], deps = [ ":dnn", ":head", @@ -89,7 +92,7 @@ py_library( py_test( name = "extenders_test", - size = "small", + size = "medium", srcs = ["python/estimator/extenders_test.py"], srcs_version = "PY2AND3", deps = [ @@ -146,6 +149,7 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 9b14622ff6436efcf66dae311f773c8375b2cafa..f8648fe5bf11e88d1fc16056d575a2a70290df0b 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -59,7 +59,7 @@ def multi_class_head(n_classes, `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for multi class classification. @@ -98,7 +98,7 @@ def binary_classification_head( `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for binary classification. @@ -129,7 +129,7 @@ def regression_head(weight_column=None, of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for linear regression. @@ -172,7 +172,7 @@ def multi_label_head(n_classes, string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for multi-label classification. @@ -227,6 +227,12 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access return self._n_classes def _process_labels(self, labels): + if labels is None: + raise ValueError( + 'You must provide a labels Tensor. Given: None. ' + 'Suggested troubleshooting steps: Check that your data contain ' + 'your label feature. Check that your input_fn properly parses and ' + 'returns labels.') if isinstance(labels, sparse_tensor.SparseTensor): if labels.dtype == dtypes.string: label_ids_values = lookup_ops.index_table_from_tensor( @@ -266,7 +272,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" - with ops.name_scope('head'): + with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access # Predict. diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 9dd9e433277304b320ac17d6478383531f114806..dcbe62b49730baf6d9a98f49e71e9877b185aabb 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -261,6 +262,18 @@ class MultiLabelHead(test.TestCase): actual_unweighted_loss.eval( {labels_placeholder: np.array([1, 1], dtype=np.int64)}) + def test_eval_labels_none(self): + """Tests that error is raised when labels is None.""" + head = head_lib.multi_label_head(n_classes=2) + + with self.assertRaisesRegexp( + ValueError, r'You must provide a labels Tensor\. Given: None\.'): + head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=None) + def _test_eval(self, head, logits, labels, expected_loss, expected_metrics): spec = head.create_estimator_spec( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -504,6 +517,22 @@ class MultiLabelHead(test.TestCase): self.assertAllClose( expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) + def test_train_labels_none(self): + """Tests that error is raised when labels is None.""" + head = head_lib.multi_label_head(n_classes=2) + def _no_op_train_fn(loss): + del loss + return control_flow_ops.no_op() + + with self.assertRaisesRegexp( + ValueError, r'You must provide a labels Tensor\. Given: None\.'): + head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=None, + train_op_fn=_no_op_train_fn) + def _test_train(self, head, logits, labels, expected_loss): expected_train_result = 'my_train_op' def _train_op_fn(loss): diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index c468c544d372e8bfd6adfa49a58e9bf6c5ef0a8b..44095bd00a7a098a8a89ba4d25c68a2484c00a6e 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -8,6 +8,7 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") @@ -23,6 +24,7 @@ tf_custom_op_py_library( "python/ops/factorization_ops.py", "python/ops/gmm.py", "python/ops/gmm_ops.py", + "python/ops/kmeans.py", "python/ops/wals.py", ], dso = [ @@ -199,6 +201,29 @@ tf_py_test( ) # Estimators tests +py_test( + name = "kmeans_test", + size = "medium", + srcs = ["python/ops/kmeans_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/67512932 + deps = [ + ":factorization_py", + ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:random_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "wals_test", size = "large", @@ -221,7 +246,6 @@ tf_py_test( "manual", "noasan", # times out b/63678675 "nomsan", - "notsan", ], ) diff --git a/tensorflow/contrib/factorization/__init__.py b/tensorflow/contrib/factorization/__init__.py index 486c2ea9336d19fb7273d02502f9865adc6aefed..6112c9d8300fe219c8e172a5b70e4ce4cad04eb6 100644 --- a/tensorflow/contrib/factorization/__init__.py +++ b/tensorflow/contrib/factorization/__init__.py @@ -23,22 +23,24 @@ from tensorflow.contrib.factorization.python.ops.clustering_ops import * from tensorflow.contrib.factorization.python.ops.factorization_ops import * from tensorflow.contrib.factorization.python.ops.gmm import * from tensorflow.contrib.factorization.python.ops.gmm_ops import * +from tensorflow.contrib.factorization.python.ops.kmeans import * from tensorflow.contrib.factorization.python.ops.wals import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'KMeans', 'COSINE_DISTANCE', - 'KMEANS_PLUS_PLUS_INIT', - 'RANDOM_INIT', - 'SQUARED_EUCLIDEAN_DISTANCE', - 'WALSModel', 'GMM', 'gmm', 'GmmAlgorithm', + 'KMeans', + 'KMEANS_PLUS_PLUS_INIT', + 'KMeansClustering', + 'RANDOM_INIT', + 'SQUARED_EUCLIDEAN_DISTANCE', 'WALSMatrixFactorization', + 'WALSModel', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/factorization/examples/mnist.py b/tensorflow/contrib/factorization/examples/mnist.py index 9eefbccd4dadf37bbaf728b44007d642fc1439e4..06a62db0049d5c698b2378d207e3657fe66acdbd 100644 --- a/tensorflow/contrib/factorization/examples/mnist.py +++ b/tensorflow/contrib/factorization/examples/mnist.py @@ -142,7 +142,7 @@ def inference(inp, num_clusters, hidden1_units, hidden2_units): # initial_clusters=tf.contrib.factorization.KMEANS_PLUS_PLUS_INIT, use_mini_batch=True) - (all_scores, _, clustering_scores, _, _, kmeans_init, + (all_scores, _, clustering_scores, _, kmeans_init, kmeans_training_op) = kmeans.training_graph() # Some heuristics to approximately whiten this output. all_scores = (all_scores[0] - 0.5) * 5 diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index e5c918066217371b076aa23c2e28650608f93fb0..d7320aeb3def08d23a256dcfee242bb4ecd9b6bd 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -51,6 +51,9 @@ COSINE_DISTANCE = 'cosine' RANDOM_INIT = 'random' KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus' +# The name of the variable holding the cluster centers. Used by the Estimator. +CLUSTERS_VAR_NAME = 'clusters' + class KMeans(object): """Creates the graph for k-means clustering.""" @@ -279,7 +282,7 @@ class KMeans(object): """ init_value = array_ops.constant([], dtype=dtypes.float32) cluster_centers = variable_scope.variable( - init_value, name='clusters', validate_shape=False) + init_value, name=CLUSTERS_VAR_NAME, validate_shape=False) cluster_centers_initialized = variable_scope.variable( False, dtype=dtypes.bool, name='initialized') @@ -337,7 +340,6 @@ class KMeans(object): assigned cluster instead. cluster_centers_initialized: scalar indicating whether clusters have been initialized. - cluster_centers_var: a Variable holding the cluster centers. init_op: an op to initialize the clusters. training_op: an op that runs an iteration of training. """ @@ -381,7 +383,7 @@ class KMeans(object): inputs, num_clusters, cluster_idx, cluster_centers_var) return (all_scores, cluster_idx, scores, cluster_centers_initialized, - cluster_centers_var, init_op, training_op) + init_op, training_op) def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var, cluster_centers_updated, total_counts): diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5413fc3f2642443621b33d325e3d8c893fd6ac --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -0,0 +1,397 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A canned Estimator for k-means clustering.""" + +# TODO(ccolby): Move clustering_ops.py into this file and streamline the code. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from tensorflow.contrib.factorization.python.ops import clustering_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util + + +class _LossRelativeChangeHook(session_run_hook.SessionRunHook): + """Stops when the change in loss goes below a tolerance.""" + + def __init__(self, loss_tensor, tolerance): + """Creates a _LossRelativeChangeHook. + + Args: + loss_tensor: A scalar tensor of the loss value. + tolerance: A relative tolerance of loss change between iterations. + """ + self._loss_tensor = loss_tensor + self._tolerance = tolerance + self._prev_loss = None + + def before_run(self, run_context): + del run_context # unused + return session_run_hook.SessionRunArgs(self._loss_tensor) + + def after_run(self, run_context, run_values): + loss = run_values.results + assert loss is not None + if self._prev_loss: + relative_change = (abs(loss - self._prev_loss) / + (1 + abs(self._prev_loss))) + if relative_change < self._tolerance: + run_context.request_stop() + self._prev_loss = loss + + +class _InitializeClustersHook(session_run_hook.SessionRunHook): + """Initializes the cluster centers. + + The chief repeatedly invokes an initialization op until all cluster centers + are initialized. The workers wait for the initialization phase to complete. + """ + + def __init__(self, init_op, is_initialized_var, is_chief): + """Creates an _InitializeClustersHook. + + Args: + init_op: An op that, when run, will choose some initial cluster centers. + This op may need to be run multiple times to choose all the centers. + is_initialized_var: A boolean variable reporting whether all initial + centers have been chosen. + is_chief: A boolean specifying whether this task is the chief. + """ + self._init_op = init_op + self._is_initialized_var = is_initialized_var + self._is_chief = is_chief + + def after_create_session(self, session, coord): + del coord # unused + assert self._init_op.graph is ops.get_default_graph() + assert self._is_initialized_var.graph is self._init_op.graph + while True: + try: + if session.run(self._is_initialized_var): + break + elif self._is_chief: + session.run(self._init_op) + else: + time.sleep(1) + except RuntimeError as e: + logging.info(e) + + +def _parse_tensor_or_dict(features): + """Helper function to convert the input points into a usable format. + + Args: + features: The input points. + + Returns: + If `features` is a dict of `k` features, each of which is a vector of `n` + scalars, the return value is a Tensor of shape `(n, k)` representing `n` + input points, where the items in the `k` dimension are sorted + lexicographically by `features` key. If `features` is not a dict, it is + returned unmodified. + """ + if isinstance(features, dict): + keys = sorted(features.keys()) + with ops.colocate_with(features[keys[0]]): + features = array_ops.concat([features[k] for k in keys], axis=1) + return features + + +class _ModelFn(object): + """Model function for the estimator.""" + + def __init__(self, num_clusters, initial_clusters, distance_metric, + random_seed, use_mini_batch, mini_batch_steps_per_iteration, + kmeans_plus_plus_num_retries, relative_tolerance): + self._num_clusters = num_clusters + self._initial_clusters = initial_clusters + self._distance_metric = distance_metric + self._random_seed = random_seed + self._use_mini_batch = use_mini_batch + self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration + self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._relative_tolerance = relative_tolerance + + def model_fn(self, features, mode, config): + """Model function for the estimator. + + Note that this does not take a `1abels` arg. This works, but `input_fn` must + return either `features` or, equivalently, `(features, None)`. + + Args: + features: The input points. See @{tf.estimator.Estimator}. + mode: See @{tf.estimator.Estimator}. + config: See @{tf.estimator.Estimator}. + + Returns: + A @{tf.estimator.EstimatorSpec} (see @{tf.estimator.Estimator}) specifying + this behavior: + * `train_op`: Execute one mini-batch or full-batch run of Lloyd's + algorithm. + * `loss`: The sum of the squared distances from each input point to its + closest center. + * `eval_metric_ops`: Maps `SCORE` to `loss`. + * `predictions`: Maps `ALL_DISTANCES` to the distance from each input + point to each cluster center; maps `CLUSTER_INDEX` to the index of + the closest cluster center for each input point. + """ + # input_points is a single Tensor. Therefore, the sharding functionality + # in clustering_ops is unused, and some of the values below are lists of a + # single item. + input_points = _parse_tensor_or_dict(features) + + # Let N = the number of input_points. + # all_distances: A list of one matrix of shape (N, num_clusters). Each value + # is the distance from an input point to a cluster center. + # model_predictions: A list of one vector of shape (N). Each value is the + # cluster id of an input point. + # losses: Similar to cluster_idx but provides the distance to the cluster + # center. + # is_initialized: scalar indicating whether the initial cluster centers + # have been chosen; see init_op. + # cluster_centers_var: a Variable containing the cluster centers. + # init_op: an op to choose the initial cluster centers. A single worker + # repeatedly executes init_op until is_initialized becomes True. + # training_op: an op that runs an iteration of training, either an entire + # Lloyd iteration or a mini-batch of a Lloyd iteration. Multiple workers + # may execute this op, but only after is_initialized becomes True. + (all_distances, model_predictions, losses, is_initialized, init_op, + training_op) = clustering_ops.KMeans( + inputs=input_points, + num_clusters=self._num_clusters, + initial_clusters=self._initial_clusters, + distance_metric=self._distance_metric, + use_mini_batch=self._use_mini_batch, + mini_batch_steps_per_iteration=self._mini_batch_steps_per_iteration, + random_seed=self._random_seed, + kmeans_plus_plus_num_retries=self._kmeans_plus_plus_num_retries + ).training_graph() + + loss = math_ops.reduce_sum(losses) + summary.scalar('loss/raw', loss) + + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) + training_op = control_flow_ops.with_dependencies([training_op, incr_step], + loss) + + training_hooks = [ + _InitializeClustersHook(init_op, is_initialized, config.is_chief) + ] + if self._relative_tolerance is not None: + training_hooks.append( + _LossRelativeChangeHook(loss, self._relative_tolerance)) + + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions={ + KMeansClustering.ALL_DISTANCES: all_distances[0], + KMeansClustering.CLUSTER_INDEX: model_predictions[0], + }, + loss=loss, + train_op=training_op, + eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)}, + training_hooks=training_hooks) + + +# TODO(agarwal,ands): support sharded input. +class KMeansClustering(estimator.Estimator): + """An Estimator for K-Means clustering.""" + + # Valid values for the distance_metric constructor argument. + SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE + COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE + + # Values for initial_clusters constructor argument. + RANDOM_INIT = clustering_ops.RANDOM_INIT + KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT + + # Metric returned by evaluate(): The sum of the squared distances from each + # input point to its closest center. + SCORE = 'score' + + # Keys returned by predict(). + # ALL_DISTANCES: The distance from each input point to each cluster center. + # CLUSTER_INDEX: The index of the closest cluster center for each input point. + CLUSTER_INDEX = 'cluster_index' + ALL_DISTANCES = 'all_distances' + + def __init__(self, + num_clusters, + model_dir=None, + initial_clusters=RANDOM_INIT, + distance_metric=SQUARED_EUCLIDEAN_DISTANCE, + random_seed=0, + use_mini_batch=True, + mini_batch_steps_per_iteration=1, + kmeans_plus_plus_num_retries=2, + relative_tolerance=None, + config=None): + """Creates an Estimator for running KMeans training and inference. + + This Estimator implements the following variants of the K-means algorithm: + + If `use_mini_batch` is False, it runs standard full batch K-means. Each + training step runs a single iteration of K-Means and must process the full + input at once. To run in this mode, the `input_fn` passed to `train` must + return the entire input dataset. + + If `use_mini_batch` is True, it runs a generalization of the mini-batch + K-means algorithm. It runs multiple iterations, where each iteration is + composed of `mini_batch_steps_per_iteration` steps. Each training step + accumulates the contribution from one mini-batch into temporary storage. + Every `mini_batch_steps_per_iteration` steps, the cluster centers are + updated and the temporary storage cleared for the next iteration. Note + that: + * If `mini_batch_steps_per_iteration=1`, the algorithm reduces to the + standard K-means mini-batch algorithm. + * If `mini_batch_steps_per_iteration = num_inputs / batch_size`, the + algorithm becomes an asynchronous version of the full-batch algorithm. + However, there is no guarantee by this implementation that each input + is seen exactly once per iteration. Also, different updates are applied + asynchronously without locking. So this asynchronous version may not + behave exactly like a full-batch version. + + Args: + num_clusters: An integer tensor specifying the number of clusters. This + argument is ignored if `initial_clusters` is a tensor or numpy array. + model_dir: The directory to save the model results and log files. + initial_clusters: Specifies how the initial cluster centers are chosen. + One of the following: + * a tensor or numpy array with the initial cluster centers. + * a callable `f(inputs, k)` that selects and returns up to `k` centers + from an input batch. `f` is free to return any number of centers + from `0` to `k`. It will be invoked on successive input batches + as necessary until all `num_clusters` centers are chosen. + * `KMeansClustering.RANDOM_INIT`: Choose centers randomly from an input + batch. If the batch size is less than `num_clusters` then the + entire batch is chosen to be initial cluster centers and the + remaining centers are chosen from successive input batches. + * `KMeansClustering.KMEANS_PLUS_PLUS_INIT`: Use kmeans++ to choose + centers from the first input batch. If the batch size is less + than `num_clusters`, a TensorFlow runtime error occurs. + distance_metric: The distance metric used for clustering. One of: + * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance + between vectors `u` and `v` is defined as `||u - v||_2` which is + the square root of the sum of the absolute squares of the elements' + difference. + * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors + `u` and `v` is defined as `1 - (u . v) / (||u||_2 ||v||_2)`. + random_seed: Python integer. Seed for PRNG used to initialize centers. + use_mini_batch: A boolean specifying whether to use the mini-batch k-means + algorithm. See explanation above. + mini_batch_steps_per_iteration: The number of steps after which the + updated cluster centers are synced back to a master copy. Used only if + `use_mini_batch=True`. See explanation above. + kmeans_plus_plus_num_retries: For each point that is sampled during + kmeans++ initialization, this parameter specifies the number of + additional points to draw from the current distribution before selecting + the best. If a negative value is specified, a heuristic is used to + sample `O(log(num_to_sample))` additional points. Used only if + `initial_clusters=KMeansClustering.KMEANS_PLUS_PLUS_INIT`. + relative_tolerance: A relative tolerance of change in the loss between + iterations. Stops learning if the loss changes less than this amount. + This may not work correctly if `use_mini_batch=True`. + config: See @{tf.estimator.Estimator}. + + Raises: + ValueError: An invalid argument was passed to `initial_clusters` or + `distance_metric`. + """ + if isinstance(initial_clusters, str) and initial_clusters not in [ + KMeansClustering.RANDOM_INIT, KMeansClustering.KMEANS_PLUS_PLUS_INIT + ]: + raise ValueError( + "Unsupported initialization algorithm '%s'" % initial_clusters) + if distance_metric not in [ + KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + KMeansClustering.COSINE_DISTANCE + ]: + raise ValueError("Unsupported distance metric '%s'" % distance_metric) + super(KMeansClustering, self).__init__( + model_fn=_ModelFn( + num_clusters, initial_clusters, distance_metric, random_seed, + use_mini_batch, mini_batch_steps_per_iteration, + kmeans_plus_plus_num_retries, relative_tolerance).model_fn, + model_dir=model_dir, + config=config) + + def _predict_one_key(self, input_fn, predict_key): + for result in self.predict(input_fn=input_fn, predict_keys=[predict_key]): + yield result[predict_key] + + def predict_cluster_index(self, input_fn): + """Finds the index of the closest cluster center to each input point. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.predict}. + + Yields: + The index of the closest cluster center for each input point. + """ + for index in self._predict_one_key(input_fn, + KMeansClustering.CLUSTER_INDEX): + yield index + + def score(self, input_fn): + """Returns the sum of squared distances to nearest clusters. + + Note that this function is different from the corresponding one in sklearn + which returns the negative sum. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.evaluate}. Only one + batch is retrieved. + + Returns: + The sum of the squared distance from each point in the first batch of + inputs to its nearest cluster center. + """ + return self.evaluate(input_fn=input_fn, steps=1)[KMeansClustering.SCORE] + + def transform(self, input_fn): + """Transforms each input point to its distances to all cluster centers. + + Note that if `distance_metric=KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`, + this + function returns the squared Euclidean distance while the corresponding + sklearn function returns the Euclidean distance. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.predict}. + + Yields: + The distances from each input point to each cluster center. + """ + for distances in self._predict_one_key(input_fn, + KMeansClustering.ALL_DISTANCES): + yield distances + + def cluster_centers(self): + """Returns the cluster centers.""" + return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4709d7942583f1406a3fa0ff3a078d0283872ea6 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -0,0 +1,575 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for KMeans.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import time + +import numpy as np +from sklearn.cluster import KMeans as SklearnKMeans + +# pylint: disable=g-import-not-at-top +from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import flags +from tensorflow.python.platform import test +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import queue_runner + +FLAGS = flags.FLAGS + + +def normalize(x): + return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True)) + + +def cosine_similarity(x, y): + return np.dot(normalize(x), np.transpose(normalize(y))) + + +def make_random_centers(num_centers, num_dims, center_norm=500): + return np.round( + np.random.rand(num_centers, num_dims).astype(np.float32) * center_norm) + + +def make_random_points(centers, num_points, max_offset=20): + num_centers, num_dims = centers.shape + assignments = np.random.choice(num_centers, num_points) + offsets = np.round( + np.random.randn(num_points, num_dims).astype(np.float32) * max_offset) + return (centers[assignments] + offsets, assignments, np.add.reduce( + offsets * offsets, 1)) + + +class KMeansTestBase(test.TestCase): + + def input_fn(self, + batch_size=None, + points=None, + randomize=None, + num_epochs=None): + """Returns an input_fn that randomly selects batches from given points.""" + batch_size = batch_size or self.batch_size + points = points if points is not None else self.points + num_points = points.shape[0] + if randomize is None: + randomize = (self.use_mini_batch and + self.mini_batch_steps_per_iteration <= 1) + + def _fn(): + x = constant_op.constant(points) + if batch_size == num_points: + return input_lib.limit_epochs(x, num_epochs=num_epochs), None + if randomize: + indices = random_ops.random_uniform( + constant_op.constant([batch_size]), + minval=0, + maxval=num_points - 1, + dtype=dtypes.int32, + seed=10) + else: + # We need to cycle through the indices sequentially. We create a queue + # to maintain the list of indices. + q = data_flow_ops.FIFOQueue(num_points, dtypes.int32, ()) + + # Conditionally initialize the Queue. + def _init_q(): + with ops.control_dependencies( + [q.enqueue_many(math_ops.range(num_points))]): + return control_flow_ops.no_op() + + init_q = control_flow_ops.cond(q.size() <= 0, _init_q, + control_flow_ops.no_op) + with ops.control_dependencies([init_q]): + offsets = q.dequeue_many(batch_size) + with ops.control_dependencies([q.enqueue_many(offsets)]): + indices = array_ops.identity(offsets) + batch = array_ops.gather(x, indices) + return (input_lib.limit_epochs(batch, num_epochs=num_epochs), None) + + return _fn + + @staticmethod + def config(tf_random_seed): + return run_config.RunConfig().replace(tf_random_seed=tf_random_seed) + + @property + def initial_clusters(self): + return kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT + + @property + def batch_size(self): + return self.num_points + + @property + def use_mini_batch(self): + return False + + @property + def mini_batch_steps_per_iteration(self): + return 1 + + +class KMeansTest(KMeansTestBase): + + def setUp(self): + np.random.seed(3) + self.num_centers = 5 + self.num_dims = 2 + self.num_points = 1000 + self.true_centers = make_random_centers(self.num_centers, self.num_dims) + self.points, _, self.scores = make_random_points(self.true_centers, + self.num_points) + self.true_score = np.add.reduce(self.scores) + + def _kmeans(self, relative_tolerance=None): + return kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + random_seed=24, + relative_tolerance=relative_tolerance) + + def test_clusters(self): + kmeans = self._kmeans() + kmeans.train(input_fn=self.input_fn(), steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims]) + + def test_fit(self): + kmeans = self._kmeans() + kmeans.train(input_fn=self.input_fn(), steps=1) + score1 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + steps = 10 * self.num_points // self.batch_size + kmeans.train(input_fn=self.input_fn(), steps=steps) + score2 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + self.assertTrue(score1 > score2) + self.assertNear(self.true_score, score2, self.true_score * 0.05) + + def test_monitor(self): + if self.use_mini_batch: + # We don't test for use_mini_batch case since the loss value can be noisy. + return + kmeans = kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(14), + random_seed=12, + relative_tolerance=1e-4) + + kmeans.train( + input_fn=self.input_fn(), + # Force it to train until the relative tolerance monitor stops it. + steps=None) + score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + self.assertNear(self.true_score, score, self.true_score * 0.01) + + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.cluster_centers() + + # Make a small test set + num_points = 10 + points, true_assignments, true_offsets = make_random_points( + clusters, num_points) + input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1) + # Test predict + assignments = list(kmeans.predict_cluster_index(input_fn)) + self.assertAllEqual(assignments, true_assignments) + + # Test score + score = kmeans.score(input_fn=lambda: (constant_op.constant(points), None)) + self.assertNear(score, np.sum(true_offsets), 0.01 * score) + + # Test transform + transform = list(kmeans.transform(input_fn)) + true_transform = np.maximum( + 0, + np.sum(np.square(points), axis=1, keepdims=True) - + 2 * np.dot(points, np.transpose(clusters)) + np.transpose( + np.sum(np.square(clusters), axis=1, keepdims=True))) + self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + + +class KMeansTestMultiStageInit(KMeansTestBase): + + def test_random(self): + points = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + kmeans.train( + input_fn=self.input_fn(batch_size=1, points=points, randomize=False), + steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(points, clusters) + + def test_kmeans_plus_plus_batch_just_right(self): + points = np.array([[1, 2]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + kmeans.train( + input_fn=self.input_fn(batch_size=1, points=points, randomize=False), + steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(points, clusters) + + def test_kmeans_plus_plus_batch_too_small(self): + points = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + with self.assertRaisesOpError(AssertionError): + kmeans.train( + input_fn=self.input_fn(batch_size=4, points=points, randomize=False), + steps=1) + + +class MiniBatchKMeansTest(KMeansTest): + + @property + def batch_size(self): + return 50 + + @property + def use_mini_batch(self): + return True + + +class FullBatchAsyncKMeansTest(KMeansTest): + + @property + def batch_size(self): + return 50 + + @property + def use_mini_batch(self): + return True + + @property + def mini_batch_steps_per_iteration(self): + return self.num_points // self.batch_size + + +class KMeansCosineDistanceTest(KMeansTestBase): + + def setUp(self): + self.points = np.array( + [[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2], [0.1, 2.5], [0.2, 2], + [0.1, 3], [0.2, 4]], + dtype=np.float32) + self.num_points = self.points.shape[0] + self.true_centers = np.array( + [ + normalize( + np.mean(normalize(self.points)[0:4, :], axis=0, + keepdims=True))[0], + normalize( + np.mean(normalize(self.points)[4:, :], axis=0, + keepdims=True))[0] + ], + dtype=np.float32) + self.true_assignments = np.array([0] * 4 + [1] * 4) + self.true_score = len(self.points) - np.tensordot( + normalize(self.points), self.true_centers[self.true_assignments]) + + self.num_centers = 2 + self.kmeans = kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT, + distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(3)) + + def test_fit(self): + max_steps = 10 * self.num_points // self.batch_size + self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + centers = normalize(self.kmeans.cluster_centers()) + centers = centers[centers[:, 0].argsort()] + true_centers = self.true_centers[self.true_centers[:, 0].argsort()] + self.assertAllClose(centers, true_centers, atol=0.04) + + def test_transform(self): + self.kmeans.train(input_fn=self.input_fn(), steps=10) + centers = normalize(self.kmeans.cluster_centers()) + true_transform = 1 - cosine_similarity(self.points, centers) + transform = list( + self.kmeans.transform( + input_fn=self.input_fn(batch_size=self.num_points, num_epochs=1))) + self.assertAllClose(transform, true_transform, atol=1e-3) + + def test_predict(self): + max_steps = 10 * self.num_points // self.batch_size + self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + centers = normalize(self.kmeans.cluster_centers()) + + assignments = list( + self.kmeans.predict_cluster_index( + input_fn=self.input_fn(num_epochs=1, batch_size=self.num_points))) + self.assertAllClose( + centers[assignments], + self.true_centers[self.true_assignments], + atol=1e-2) + + centers = centers[centers[:, 0].argsort()] + true_centers = self.true_centers[self.true_centers[:, 0].argsort()] + self.assertAllClose(centers, true_centers, atol=0.04) + score = self.kmeans.score( + input_fn=self.input_fn(batch_size=self.num_points)) + self.assertAllClose(score, self.true_score, atol=1e-2) + + def test_predict_kmeans_plus_plus(self): + # Most points are concetrated near one center. KMeans++ is likely to find + # the less populated centers. + points = np.array( + [[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], [-3.1, -3.2], + [-2.8, -3.], [-2.9, -3.1], [-3., -3.1], [-3., -3.1], [-3.2, -3.], + [-3., -3.]], + dtype=np.float32) + true_centers = np.array( + [ + normalize( + np.mean(normalize(points)[0:2, :], axis=0, keepdims=True))[0], + normalize( + np.mean(normalize(points)[2:4, :], axis=0, keepdims=True))[0], + normalize(np.mean(normalize(points)[4:, :], axis=0, + keepdims=True))[0] + ], + dtype=np.float32) + true_assignments = [0] * 2 + [1] * 2 + [2] * 8 + true_score = len(points) - np.tensordot( + normalize(points), true_centers[true_assignments]) + + kmeans = kmeans_lib.KMeansClustering( + 3, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(3)) + kmeans.train( + input_fn=lambda: (constant_op.constant(points), None), steps=30) + + centers = normalize(kmeans.cluster_centers()) + self.assertAllClose( + sorted(centers.tolist()), sorted(true_centers.tolist()), atol=1e-2) + + def _input_fn(): + return (input_lib.limit_epochs( + constant_op.constant(points), num_epochs=1), None) + + assignments = list(kmeans.predict_cluster_index(input_fn=_input_fn)) + self.assertAllClose( + centers[assignments], true_centers[true_assignments], atol=1e-2) + + score = kmeans.score(input_fn=lambda: (constant_op.constant(points), None)) + self.assertAllClose(score, true_score, atol=1e-2) + + +class MiniBatchKMeansCosineTest(KMeansCosineDistanceTest): + + @property + def batch_size(self): + return 2 + + @property + def use_mini_batch(self): + return True + + +class FullBatchAsyncKMeansCosineTest(KMeansCosineDistanceTest): + + @property + def batch_size(self): + return 2 + + @property + def use_mini_batch(self): + return True + + @property + def mini_batch_steps_per_iteration(self): + return self.num_points // self.batch_size + + +class KMeansBenchmark(benchmark.Benchmark): + """Base class for benchmarks.""" + + def SetUp(self, + dimension=50, + num_clusters=50, + points_per_cluster=10000, + center_norm=500, + cluster_width=20): + np.random.seed(123456) + self.num_clusters = num_clusters + self.num_points = num_clusters * points_per_cluster + self.centers = make_random_centers( + self.num_clusters, dimension, center_norm=center_norm) + self.points, _, scores = make_random_points( + self.centers, self.num_points, max_offset=cluster_width) + self.score = float(np.sum(scores)) + + def _report(self, num_iters, start, end, scores): + print(scores) + self.report_benchmark( + iters=num_iters, + wall_time=(end - start) / num_iters, + extras={'true_sum_squared_distances': self.score, + 'fit_scores': scores}) + + def _fit(self, num_iters=10): + pass + + def benchmark_01_2dim_5center_500point(self): + self.SetUp(dimension=2, num_clusters=5, points_per_cluster=100) + self._fit() + + def benchmark_02_20dim_20center_10kpoint(self): + self.SetUp(dimension=20, num_clusters=20, points_per_cluster=500) + self._fit() + + def benchmark_03_100dim_50center_50kpoint(self): + self.SetUp(dimension=100, num_clusters=50, points_per_cluster=1000) + self._fit() + + def benchmark_03_100dim_50center_50kpoint_unseparated(self): + self.SetUp( + dimension=100, + num_clusters=50, + points_per_cluster=1000, + cluster_width=250) + self._fit() + + def benchmark_04_100dim_500center_500kpoint(self): + self.SetUp(dimension=100, num_clusters=500, points_per_cluster=1000) + self._fit(num_iters=4) + + def benchmark_05_100dim_500center_500kpoint_unseparated(self): + self.SetUp( + dimension=100, + num_clusters=500, + points_per_cluster=1000, + cluster_width=250) + self._fit(num_iters=4) + + +class TensorflowKMeansBenchmark(KMeansBenchmark): + + def _fit(self, num_iters=10): + scores = [] + start = time.time() + for i in range(num_iters): + print('Starting tensorflow KMeans: %d' % i) + tf_kmeans = kmeans_lib.KMeansClustering( + self.num_clusters, + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + kmeans_plus_plus_num_retries=int(math.log(self.num_clusters) + 2), + random_seed=i * 42, + relative_tolerance=1e-6, + config=self.config(3)) + tf_kmeans.train( + input_fn=lambda: (constant_op.constant(self.points), None), steps=50) + _ = tf_kmeans.cluster_centers() + scores.append( + tf_kmeans.score( + input_fn=lambda: (constant_op.constant(self.points), None))) + self._report(num_iters, start, time.time(), scores) + + +class SklearnKMeansBenchmark(KMeansBenchmark): + + def _fit(self, num_iters=10): + scores = [] + start = time.time() + for i in range(num_iters): + print('Starting sklearn KMeans: %d' % i) + sklearn_kmeans = SklearnKMeans( + n_clusters=self.num_clusters, + init='k-means++', + max_iter=50, + n_init=1, + tol=1e-4, + random_state=i * 42) + sklearn_kmeans.train(self.points) + scores.append(sklearn_kmeans.inertia_) + self._report(num_iters, start, time.time(), scores) + + +class KMeansTestQueues(test.TestCase): + + def input_fn(self): + + def _fn(): + queue = data_flow_ops.FIFOQueue( + capacity=10, dtypes=dtypes.float32, shapes=[10, 3]) + enqueue_op = queue.enqueue(array_ops.zeros([10, 3], dtype=dtypes.float32)) + queue_runner.add_queue_runner( + queue_runner.QueueRunner(queue, [enqueue_op])) + return queue.dequeue(), None + + return _fn + + # This test makes sure that there are no deadlocks when using a QueueRunner. + # Note that since cluster initialization is dependendent on inputs, if input + # is generated using a QueueRunner, one has to make sure that these runners + # are started before the initialization. + def test_queues(self): + kmeans = kmeans_lib.KMeansClustering(5) + kmeans.train(input_fn=self.input_fn(), steps=1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3e3ee5fa57f1356db98a17f9e17e60f01d85d3b9..3976395d78e9188dd56d5b3b32fa8a3daf43c37d 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -38,31 +37,30 @@ from tensorflow.python.training import session_run_hook class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_op, num_rows, num_cols, - processed_row_indices, processed_col_indices, row_prep_ops, - col_prep_ops, cache_init_ops, completed_sweeps_var): + def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, + input_row_indices, input_col_indices, row_prep_ops, + col_prep_ops, init_op, completed_sweeps_var): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_op: An op. All the ops created by the hook will have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this hook will have + control dependencies on `train_ops`. num_rows: int, the total number of rows to be processed. num_cols: int, the total number of columns to be processed. - processed_row_indices: A Tensor of type int64. The indices of the input - rows that are processed during the current sweep. All elements of - processed_row_indices must be in [0, num_rows). - processed_col_indices: A Tensor of type int64. The indices of the input + input_row_indices: A Tensor of type int64. The indices of the input rows + that are processed during the current sweep. All elements of + `input_row_indices` must be in [0, num_rows). + input_col_indices: A Tensor of type int64. The indices of the input columns that are processed during the current sweep. All elements of - processed_col_indices must be in [0, num_cols). + `input_col_indices` must be in [0, num_cols). row_prep_ops: list of ops, to be run before the beginning of each row sweep, in the given order. col_prep_ops: list of ops, to be run before the beginning of each column sweep, in the given order. - cache_init_ops: list of ops, to be run once before training, in the given - order. These are typically local initialization ops (such as cache - initialization). + init_op: op to be run once before training. This is typically a local + initialization op (such as cache initialization). completed_sweeps_var: An integer tf.Variable, indicates the number of completed sweeps. It is updated by the hook. """ @@ -70,55 +68,45 @@ class _SweepHook(session_run_hook.SessionRunHook): self._num_cols = num_cols self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._cache_init_ops = cache_init_ops + self._init_op = init_op self._is_row_sweep_var = is_row_sweep_var self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the cache_init_ops have been run. + # Boolean variable that determines whether the init_ops have been run. self._is_initialized = False - # Boolean variable that is set to True when a sweep is completed. - # Used to run the prep_ops at the beginning of a sweep, in before_run(). - self._is_sweep_done = False - # Ops to run jointly with train_op, responsible for updating - # _is_row_sweep_var and incrementing the global_step and completed_sweeps - # counters. They have control_dependencies on train_op. - self._fetches = self._create_switch_ops(processed_row_indices, - processed_col_indices, train_op) - - def _create_switch_ops(self, processed_row_indices, processed_col_indices, - train_op): + # Ops to run jointly with train_ops, responsible for updating + # `is_row_sweep_var` and incrementing the `global_step` and + # `completed_sweeps` counters. + self._update_op, self._is_sweep_done_var, self._switch_op = ( + self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) + + def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - Creates two boolean tensors processed_rows and processed_cols, which keep - track of which rows/cols have been processed during the current sweep. + Creates two boolean tensors `processed_rows` and `processed_cols`, which + keep track of which rows/cols have been processed during the current sweep. Returns ops that should be run after each row / col update. - - When is_row_sweep_var is True, it sets - processed_rows[processed_row_indices] to True. - - When is_row_sweep_var is False, it sets - processed_cols[processed_col_indices] to True . - When all rows or all cols have been processed, negates is_row_sweep_var, - increments the completed_sweeps counter, and resets processed_rows and - processed_cols to False. - All of the ops created by this function have control_dependencies on - train_op. + - When `self._is_row_sweep_var` is True, it sets + processed_rows[input_row_indices] to True. + - When `self._is_row_sweep_var` is False, it sets + processed_cols[input_col_indices] to True. Args: - processed_row_indices: A Tensor. The indices of the input rows that are + input_row_indices: A Tensor. The indices of the input rows that are processed during the current sweep. - processed_col_indices: A Tensor. The indices of the input columns that + input_col_indices: A Tensor. The indices of the input columns that are processed during the current sweep. - train_op: An op. All the ops created by this function have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this function have control + dependencies on `train_ops`. + Returns: - A list consisting of: - is_sweep_done: A Boolean tensor, determines whether the sweep is done, - i.e. all rows (during a row sweep) or all columns (during a column - sweep) have been processed. - switch_ops: An op that updates is_row_sweep_var when is_sweep_done is - True. Has control_dependencies on train_op. - incr_ops: An op that increments the global_step and completed_sweeps - counters. Has control_dependenciens on switch_ops. + A tuple consisting of: + update_op: An op to be run jointly with training. It updates the state + and increments counters (global step and completed sweeps). + is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is + done, i.e. all rows (during a row sweep) or all columns (during a + column sweep) have been processed. + switch_op: An op to be run in `self.before_run` when the sweep is done. """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) with ops.colocate_with(processed_rows_init): processed_rows = variable_scope.variable( @@ -133,97 +121,72 @@ class _SweepHook(session_run_hook.SessionRunHook): collections=[ops.GraphKeys.GLOBAL_VARIABLES], trainable=False, name="sweep_hook_processed_cols") - # After running the train_op, update processed_rows or processed_cols - # tensors, depending on whether we are currently doing a row or a col sweep - with ops.control_dependencies([train_op]): - - def get_row_update_op(): - with ops.colocate_with(processed_rows): - return state_ops.scatter_update(processed_rows, processed_row_indices, - array_ops.ones_like( - processed_row_indices, - dtype=dtypes.bool)) - - def get_col_update_op(): - with ops.colocate_with(processed_cols): - return state_ops.scatter_update(processed_cols, processed_col_indices, - array_ops.ones_like( - processed_col_indices, - dtype=dtypes.bool)) - - update_processed_op = control_flow_ops.cond( - self._is_row_sweep_var, get_row_update_op, get_col_update_op) - - # After update_processed_op, check whether we have completed a sweep. - # If this is the case, flip the is_row_sweep_var and reset processed_rows - # and processed_cols tensors. - with ops.control_dependencies([update_processed_op]): - - def get_switch_op(): - return state_ops.assign( - self._is_row_sweep_var, - gen_math_ops.logical_not(self._is_row_sweep_var)).op - - def get_reset_op(): - return control_flow_ops.group( - state_ops.assign(processed_rows, processed_rows_init).op, - state_ops.assign(processed_cols, processed_cols_init).op) - - is_sweep_done = control_flow_ops.cond( + switch_ops = control_flow_ops.group( + state_ops.assign( self._is_row_sweep_var, - lambda: math_ops.reduce_all(processed_rows), - lambda: math_ops.reduce_all(processed_cols), - name="sweep_hook_is_sweep_done") - switch_op = control_flow_ops.cond( - is_sweep_done, - get_switch_op, - control_flow_ops.no_op, - name="sweep_hook_switch_op") - reset_op = control_flow_ops.cond( - is_sweep_done, - get_reset_op, - control_flow_ops.no_op, - name="sweep_hook_reset_op") - switch_ops = control_flow_ops.group( - switch_op, reset_op, name="sweep_hook_switch_ops") - - with ops.control_dependencies([switch_ops]): - # Op to increment the completed_sweeps counter. - completed_sweeps_incr_op = control_flow_ops.cond( - is_sweep_done, - lambda: state_ops.assign_add(self._completed_sweeps_var, 1).op, - control_flow_ops.no_op, - name="completed_sweeps_incr") - - # Op to increment the global_step counter. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op( - name="global_step_incr") - - incr_ops = control_flow_ops.group( - completed_sweeps_incr_op, - global_step_incr_op, - name="counter_incr_ops") - - return [is_sweep_done, switch_ops, incr_ops] + math_ops.logical_not(self._is_row_sweep_var)), + state_ops.assign(processed_rows, processed_rows_init), + state_ops.assign(processed_cols, processed_cols_init)) + is_sweep_done_var = variable_scope.variable( + False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="is_sweep_done") + + # After running the `train_ops`, updates `processed_rows` or + # `processed_cols` tensors, depending on whether this is a row or col sweep. + with ops.control_dependencies(train_ops): + with ops.colocate_with(processed_rows): + update_processed_rows = state_ops.scatter_update( + processed_rows, + input_row_indices, + math_ops.logical_and( + self._is_row_sweep_var, + array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) + with ops.colocate_with(processed_cols): + update_processed_cols = state_ops.scatter_update( + processed_cols, + input_col_indices, + math_ops.logical_and( + math_ops.logical_not(self._is_row_sweep_var), + array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) + update_processed_op = control_flow_ops.group( + update_processed_rows, update_processed_cols) - def begin(self): - pass + with ops.control_dependencies([update_processed_op]): + is_sweep_done = math_ops.logical_or( + math_ops.reduce_all(processed_rows), + math_ops.reduce_all(processed_cols)) + # Increments global step. + global_step = framework_variables.get_global_step() + if global_step is not None: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op() + # Increments completed sweeps. + completed_sweeps_incr_op = state_ops.assign_add( + self._completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32), + use_locking=True).op + update_ops = control_flow_ops.group( + global_step_incr_op, + completed_sweeps_incr_op, + state_ops.assign(is_sweep_done_var, is_sweep_done)) + + return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Run the appropriate cache_init and prep ops + # Runs the appropriate init ops and prep ops. sess = run_context.session + is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init ops.") - for init_op in self._cache_init_ops: - sess.run(init_op) - - if self._is_sweep_done or not self._is_initialized: + logging.info("SweepHook running cache init op.") + sess.run(self._init_op) + if is_sweep_done: + sess.run(self._switch_op) + if is_sweep_done or not self._is_initialized: logging.info("SweepHook running sweep prep ops.") row_sweep = sess.run(self._is_row_sweep_var) prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops @@ -232,13 +195,12 @@ class _SweepHook(session_run_hook.SessionRunHook): self._is_initialized = True - # Request running the switch_ops and the incr_ops - logging.info("Partial fit starting.") - return session_run_hook.SessionRunArgs(fetches=self._fetches) + # Requests running `self._update_op` jointly with the training op. + logging.info("Next fit step starting.") + return session_run_hook.SessionRunArgs(fetches=[self._update_op]) def after_run(self, run_context, run_values): - self._is_sweep_done = run_values.results[0] - logging.info("Partial fit done.") + logging.info("Fit step done.") class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -360,19 +322,19 @@ def _wals_factorization_model_function(features, labels, mode, params): col_prep_ops = [ model.col_update_prep_gramian_op, model.initialize_col_update_op ] - cache_init_ops = [model.worker_init] + init_ops = [model.worker_init] sweep_hook = _SweepHook( is_row_sweep_var, - train_op, + [train_op, loss], params["num_rows"], params["num_cols"], input_row_indices, input_col_indices, row_prep_ops, col_prep_ops, - cache_init_ops, - completed_sweeps_var,) + init_ops, + completed_sweeps_var) training_hooks = [sweep_hook] if max_sweeps is not None: training_hooks.append(_StopAtSweepHook(max_sweeps)) diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index b5c1bb1151e78a8f19d3c91b57ef3bfd6152893d..8bd72b7025aad80e387171b93b9b264da3ed0f66 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -357,7 +357,7 @@ class WALSMatrixFactorizationTest(test.TestCase): self.assertNear( loss, true_loss, err=.001, - msg="""After row update, eval loss = {}, does not match the true + msg="""After col update, eval loss = {}, does not match the true loss = {}.""".format(loss, true_loss)) @@ -442,7 +442,7 @@ class SweepHookTest(test.TestCase): completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - self._train_op, + [self._train_op], self._num_rows, self._num_cols, self._input_row_indices_ph, @@ -465,11 +465,9 @@ class SweepHookTest(test.TestCase): 'False.') # Row sweep completed. mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertFalse(sess.run(is_row_sweep_var), - msg='Row sweep is complete but is_row_sweep is True.') self.assertTrue(sess.run(completed_sweeps_var) == 1, msg='Completed sweeps should be equal to 1.') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') # Col init ops should run. Col sweep not completed. mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) @@ -478,13 +476,11 @@ class SweepHookTest(test.TestCase): self.assertFalse(sess.run(is_row_sweep_var), msg='Col sweep is not complete but is_row_sweep is ' 'True.') - self.assertFalse(sweep_hook._is_sweep_done, + self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is not complete but is_sweep_done is True.') # Col sweep completed. mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(is_row_sweep_var), - msg='Col sweep is complete but is_row_sweep is False') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') self.assertTrue(sess.run(completed_sweeps_var) == 2, msg='Completed sweeps should be equal to 2.') diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 6b0599ddd2def8dd698a1bd152b5be926c6ddf3e..dd882acb8ee35a91f2e67511b1465b3a561d72a6 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -10,9 +10,8 @@ package(default_visibility = [ "//tensorflow:__subpackages__", ]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") @@ -27,6 +26,7 @@ tf_custom_op_py_library( "python/framework/experimental.py", "python/framework/tensor_util.py", "python/ops/__init__.py", + "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", @@ -149,6 +149,31 @@ py_test( ], ) +py_test( + name = "accumulate_n_v2_test", + size = "small", + srcs = ["python/ops/accumulate_n_v2_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +py_test( + name = "accumulate_n_v2_eager_test", + size = "small", + srcs = ["python/ops/accumulate_n_v2_eager_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/eager:backprop", + ], +) + py_test( name = "ops_test", size = "small", diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index e595e4d90bfd6bef9de5ac0724a18060e7458f8e..92a2a4ff2d1cb41c48312038d82be0b6136f8d41 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -78,9 +78,9 @@ def reduce_sum_n(tensors, name=None): return math_ops.add_n(tensors, name=name_scope) @deprecated(None, - "Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note " - "that order of the inputs and ouputs of labels and predictions have also " - "been switched.") + 'Please switch to tf.confusion_matrix.remove_squeezable_dimensions.' + 'Note that order of the inputs and outputs of labels and ' + 'predictions have also been switched.') def remove_squeezable_dimensions(predictions, labels, name=None): """Squeeze last dim if ranks of `predictions` and `labels` differ by 1. diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..a0667bd489213cf366e27114a91e8699ed9e7428 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -0,0 +1,111 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops + + + +def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): + """Returns the element-wise sum of a list of tensors. + + Optionally, pass `shape` and `tensor_dtype` for shape and type checking, + otherwise, these are inferred. + + `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not + wait for all of its inputs to be ready before beginning to sum. This can + save memory if inputs are ready at different times, since minimum temporary + storage is proportional to the output size rather than the inputs size. + + Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. + + For example: + + ```python + a = tf.constant([[1, 2], [3, 4]]) + b = tf.constant([[5, 0], [0, 6]]) + tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] + + # Explicitly pass shape and type + tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + # [[7, 4], + # [6, 14]] + ``` + + Args: + inputs: A list of `Tensor` objects, each with same shape and type. + shape: Shape of elements of `inputs`. + tensor_dtype: The type of `inputs`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of same shape and type as the elements of `inputs`. + + Raises: + ValueError: If `inputs` don't all have same shape and dtype or the shape + cannot be inferred. + """ + _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" + "with the same dtype and shape") + if not inputs or not isinstance(inputs, (list, tuple)): + raise _INPUTS_ERR_MSG + inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) + if not all(isinstance(x, ops.Tensor) for x in inputs): + raise _INPUTS_ERR_MSG + if not all(x.dtype == inputs[0].dtype for x in inputs): + raise _INPUTS_ERR_MSG + if shape is not None: + shape = tensor_shape.as_shape(shape) + else: + shape = tensor_shape.unknown_shape() + for input_tensor in inputs: + if isinstance(input_tensor, ops.Tensor): + shape = shape.merge_with(input_tensor.get_shape()) + + # tensor_dtype is for safety only; operator's output type computed in C++ + if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: + raise TypeError("tensor_dtype is {}, but input is of type {}" + .format(tensor_dtype, inputs[0].dtype)) + + if len(inputs) == 1 and name is None: + return inputs[0] + elif len(inputs) == 1 and name is not None: + return array_ops.identity(inputs[0], name=name) + elif context.in_eager_mode(): + # TemporaryVariable not currently supported in eager mode; fall back + # onto AddN for now. + # TODO(frreiss) remove this once the lifetime of eager variables gets + # addressed + return math_ops.add_n(inputs, name=name) + else: + return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) + +# The following code should eventually be merged into +# tensorflow/python/ops/math_grad.py +@ops.RegisterGradient("AccumulateNV2") +def _AddNGrad(op, grad): + """Same as gradient for AddN. Copies the gradient to all inputs.""" + # Not broadcasting. + return [grad] * len(op.inputs) + diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8c618838bfbcd1b0572c3a57aa6b27c68ee34f0c --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -0,0 +1,84 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for new version of accumulate_n op that will eventually go into +`ops.math_ops`. + +These test cases spefically exercise the `eager` APIs. They need to be in a +separate file from the remaining tests because eager mode is currently something +you can turn on but can't turn off for the lifetime of the current process.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 + +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context as eager_context +from tensorflow.python.eager import tape + + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + + +class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): + """Tests of the new, differentiable version of accumulate_n""" + + def testMinimalEagerMode(self): + forty = constant_op.constant(40) + two = constant_op.constant(2) + answer = av2.accumulate_n_v2([forty, two]) + self.assertEqual(42, answer.numpy()) + + + def testFloat(self): + np.random.seed(12345) + x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).numpy()) + self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).numpy()) + + def testGrad(self): + np.random.seed(42) + num_inputs = 3 + input_vars = [ + resource_variable_ops.ResourceVariable(10.0 * np.random.random()) + for i in range(0, num_inputs) + ] + + def fn(first, second, third): + return av2.accumulate_n_v2([first, second, third]) + + grad_fn = backprop.gradients_function(fn) + grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [elem.numpy() for elem in grad]) + + + +if __name__ == "__main__": + eager_context.enable_eager_execution() + test.main() + diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3386e849d5cb8516ab3b1f6cb0429be3fc2fc960 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -0,0 +1,123 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for new version of accumulate_n op that will eventually go into +`ops.math_ops`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + + +class AccumulateNV2Test(test_util.TensorFlowTestCase): + """Tests of the new, differentiable version of accumulate_n""" + + def testFloat(self): + np.random.seed(12345) + x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).eval()) + self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).eval()) + + def testInt(self): + np.random.seed(54321) + x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllEqual(sum(x), av2.accumulate_n_v2(tf_x).eval()) + self.assertAllEqual(x[0] * 6, av2.accumulate_n_v2([tf_x[0]] * 6).eval()) + + def testGrad(self): + np.random.seed(42) + for num_inputs in range(1, 10): + with self.test_session(use_gpu=True) as sess: + input_vars = [ + variables.Variable(10.0 * np.random.random()) + for i in range(0, num_inputs) + ] + accum_n = av2.accumulate_n_v2(input_vars) + sess.run(variables.global_variables_initializer()) + accum_n_grad = gradients.gradients(accum_n, input_vars) + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [g.eval() for g in accum_n_grad]) + + # The tests below used to be in a separate class under cwise_ops_test.py, + # which did not run in the default test target. + # Putting them here so that everything that exercises AccumulateNV2 is in + # one place and the default build runs all unit tests. + def testSimple(self): + with self.test_session(): + random_arrays = [ + np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20) + ] + random_tensors = [ + ops.convert_to_tensor( + x, dtype=dtypes_lib.float32) for x in random_arrays + ] + tf_val = av2.accumulate_n_v2(random_tensors) + np_val = random_arrays[0] + for random_array in random_arrays[1:]: + np_val += random_array + self.assertAllClose(np_val, tf_val.eval()) + + def testZeroArgs(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf_val = av2.accumulate_n_v2([]) + tf_val.eval() + + def testWrongShape(self): + with self.test_session(): + with self.assertRaises(ValueError): + a = variables.Variable(0.2) + b = variables.Variable(0.1) + tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[] + + def testIncompatibleShapes(self): + with self.test_session(): + with self.assertRaises(ValueError): + a = variables.Variable(np.array([0.1,0.2])) + b = variables.Variable(np.array([[0.3],[0.4]])) + tf_val = av2.accumulate_n_v2([a,b]) + + def testWrongType(self): + with self.test_session(): + with self.assertRaises(TypeError): + a = variables.Variable(0.2, dtype=np.float32) + b = variables.Variable(0.1, dtype=np.float32) + tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + + def testWrongTypeOneInput(self): + # Scenario that used to trigger a bug, even when testWrongType() worked + with self.test_session(): + with self.assertRaises(TypeError): + a = variables.Variable(0.2, dtype=np.float32) + tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 9275d5a22b2697c37414fba2f6176f708808e60c..e4c39739f7fc653b68e82c994fc69e3e168f65f9 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -298,6 +298,17 @@ void LaunchFusedConv2DBiasActivationOp:: constexpr int rank = is_int8x4 ? 5 : 4; constexpr int vect = is_int8x4 ? 4 : 1; + if (is_int8x4) { + int cc_major, cc_minor; + stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + OP_REQUIRES( + ctx, cc_major >= 6 && cc_minor >= 1, + errors::Unimplemented( + "FusedConv2DBiasActivation for int8 is only supported on GPUs with " + "compute capability 6.1 or later.")); + } + const int batch_size = GetTensorDim(conv_input_param, data_format, 'N'); int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H'); int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W'); @@ -493,42 +504,37 @@ void LaunchFusedConv2DBiasActivationOp:: dnn::AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find( fused_conv_parameters, &algorithm_config)) { - std::vector algorithms; + std::vector algorithms; CHECK(stream->parent()->GetConvolveAlgorithms( fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); dnn::ProfileResult best_result; dnn::ProfileResult best_result_no_scratch; - // TODO(benbarsdell): Ideally this should not attempt using tensor op math - // if it's not enabled. - for (bool use_tensor_ops : {false, true}) { - for (auto algo_index : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. - dnn::AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops); - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); - dnn::ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenFusedConvolveWithAlgorithm( - conv_input_desc, conv_input_ptr, conv_input_scale, - filter_desc, filter_ptr, conv_desc, side_input_ptr, - side_input_scale, bias_desc, bias_ptr, - dnn::ActivationMode::kRelu, output_desc, &output_ptr, - &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), - &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } + for (auto profile_algorithm : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. + CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + dnn::ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenFusedConvolveWithAlgorithm( + conv_input_desc, conv_input_ptr, conv_input_scale, + filter_desc, filter_ptr, conv_desc, side_input_ptr, + side_input_scale, bias_desc, bias_ptr, + dnn::ActivationMode::kRelu, output_desc, &output_ptr, + &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), + &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; } } } diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 54dbb11b6ebcfac8f8d687863f85a8d890fd4fb3..27a5d6ec31f0df5f0f3a435185f50a6c88122b19 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -14,6 +14,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":estimator", ":eval", ":features", ":losses", @@ -86,6 +87,17 @@ py_library( ], ) +py_library( + name = "estimator", + srcs = ["python/estimator/__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":gan_estimator", + ":head", + "//tensorflow/python:util", + ], +) + py_library( name = "losses", srcs = ["python/losses/__init__.py"], @@ -369,6 +381,90 @@ py_test( ], ) +py_library( + name = "head", + srcs = [ + "python/estimator/python/head.py", + "python/estimator/python/head_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":namedtuples", + ":train", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "head_test", + srcs = ["python/estimator/python/head_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + deps = [ + ":head", + ":namedtuples", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( + name = "gan_estimator", + srcs = [ + "python/estimator/python/gan_estimator.py", + "python/estimator/python/gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":head", + ":namedtuples", + ":summaries", + ":train", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "gan_estimator_test", + srcs = ["python/estimator/python/gan_estimator_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":gan_estimator", + ":namedtuples", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 10458a2458384c8f589183003256db24d69742d7..5d74df3ef70be03060e9cd30249359ee07b4b83a 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -51,9 +51,10 @@ network to evaluate your unconditional generative model. You can also also use your own pretrained classifier for more specific performance numbers, or use other methods for evaluating conditional generative models. -* [examples](https://github.com/tensorflow/models/tree/master/gan/): +* examples (coming soon): See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your -own project. +own project. These include unconditional and conditional GANs, InfoGANs, +adversarial losses on existing networks, and image-to-image translation. ## Training a GAN model diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index 67eee771d040995449329dde0b0cb990793176ec..dff361fdc42708ea69999c2def4721f9d49fcf14 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # Collapse TFGAN into a tiered namespace. +from tensorflow.contrib.gan.python import estimator from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin from tensorflow.contrib.gan.python import features from tensorflow.contrib.gan.python import losses @@ -33,6 +34,7 @@ from tensorflow.contrib.gan.python.train import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + 'estimator', 'eval', 'features', 'losses', diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4a18228039cb4f2c06e0333f4b8408f1f631e9 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN grouped API. Please see README.md for details and usage.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Collapse `estimator` into a single namespace. +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.gan.python.estimator.python import gan_estimator +from tensorflow.contrib.gan.python.estimator.python import head + +from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * +from tensorflow.contrib.gan.python.estimator.python.head import * +# pylint: enable=unused-import,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'gan_estimator', + 'head', +] + gan_estimator.__all__ + head.__all__ +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0e48540915d1de7e249f8640193366f37baa92 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `GANEstimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..e89993991a389d68254a95aded2d771f4c2627be --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -0,0 +1,273 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TFGAN-backed GAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import enum + +from tensorflow.contrib.framework.python.ops import variables as variable_lib +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.estimator.python import head as head_lib +from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope + + +__all__ = [ + 'GANEstimator', + 'SummaryType' +] + + +class SummaryType(enum.IntEnum): + NONE = 0 + VARIABLES = 1 + IMAGES = 2 + IMAGE_COMPARISON = 3 + + +_summary_type_map = { + SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries, + SummaryType.IMAGES: tfgan_summaries.add_gan_model_image_summaries, + SummaryType.IMAGE_COMPARISON: tfgan_summaries.add_image_comparison_summaries, # pylint:disable=line-too-long +} + + +# TODO(joelshor): For now, this only supports 1:1 generator:discriminator +# training sequentially. Find a nice way to expose options to the user without +# exposing internals. +class GANEstimator(estimator.Estimator): + """An estimator for Generative Adversarial Networks (GANs). + + This Estimator is backed by TFGAN. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... + return logits + + # Create GAN estimator. + gan_estimator = estimator.GANEstimator( + model_dir, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, + generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5)) + + # Train estimator. + gan_estimator.train(train_input_fn, steps) + + # Evaluate resulting estimator. + gan_estimator.evaluate(eval_input_fn) + + # Generate samples from generator. + predictions = np.array([ + x for x in gan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + model_dir=None, + generator_fn=None, + discriminator_fn=None, + generator_loss_fn=None, + discriminator_loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + add_summaries=None, + use_loss_summaries=True, + config=None): + """Initializes a GANEstimator instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `generator_inputs`. + Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + and examples. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` tuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. This function will + be called when the default graph is the `GANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + config: `RunConfig` object to configure the runtime settings. + """ + # TODO(joelshor): Explicitly validate inputs. + + def _model_fn(features, labels, mode): + gopt = (generator_optimizer() if callable(generator_optimizer) else + generator_optimizer) + dopt = (discriminator_optimizer() if callable(discriminator_optimizer) + else discriminator_optimizer) + gan_head = head_lib.gan_head( + generator_loss_fn, discriminator_loss_fn, gopt, dopt, + use_loss_summaries) + return _gan_model_fn( + features, labels, mode, generator_fn, discriminator_fn, gan_head, + add_summaries) + + super(GANEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +def _use_check_shapes(real_data): + """Determines whether TFGAN should check Tensor shapes.""" + return isinstance(real_data, ops.Tensor) + + +def _gan_model_fn( + features, + labels, + mode, + generator_fn, + discriminator_fn, + head, + add_summaries=None, + generator_scope_name='Generator'): + """The `model_fn` for the GAN estimator. + + We make the following convention: + features -> TFGAN's `generator_inputs` + labels -> TFGAN's `real_data` + + Args: + features: A dictionary to feed to generator. In the unconditional case, + this might be just `noise`. In the conditional GAN case, this + might be the generator's conditioning. The `generator_fn` determines + what the required keys are. + labels: Real data. Can be any structure, as long as `discriminator_fn` + can accept it for the first argument. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + generator_fn: A python lambda that takes `generator_inputs` as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. + head: A `Head` instance suitable for GANs. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + generator_scope_name: The name of the generator scope. We need this to be + the same for GANModels produced by TFGAN's `train.gan_model` and the + manually constructed ones for predictions. + + Returns: + `ModelFnOps` + + Raises: + ValueError: If `labels` isn't `None` during prediction. + """ + real_data = labels + generator_inputs = features + + if mode == model_fn_lib.ModeKeys.TRAIN: + gan_model = _make_train_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope_name, add_summaries) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_model = _make_eval_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope_name, add_summaries) + else: + if real_data is not None: + raise ValueError('`labels` must be `None` when mode is `predict`. ' + 'Instead, found %s' % real_data) + gan_model = _make_prediction_gan_model( + generator_inputs, generator_fn, generator_scope_name) + + return head.create_estimator_spec( + features=None, + mode=mode, + logits=gan_model, + labels=None) + + +def _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for training.""" + gan_model = tfgan_train.gan_model( + generator_fn, + discriminator_fn, + real_data, + generator_inputs, + generator_scope=generator_scope, + check_shapes=_use_check_shapes(real_data)) + if add_summaries: + if not isinstance(add_summaries, (tuple, list)): + add_summaries = [add_summaries] + with ops.name_scope(None): + for summary_type in add_summaries: + _summary_type_map[summary_type](gan_model) + + return gan_model + + +def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for evaluation.""" + return _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries) + + +def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): + """Make a `GANModel` from just the generator.""" + with variable_scope.variable_scope(generator_scope) as gen_scope: + generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access + generated_data = generator_fn(generator_inputs) + generator_variables = variable_lib.get_trainable_variables(gen_scope) + + return tfgan_tuples.GANModel( + generator_inputs, + generated_data, + generator_variables, + gen_scope, + generator_fn, + real_data=None, + discriminator_real_outputs=None, + discriminator_gen_outputs=None, + discriminator_variables=None, + discriminator_scope=None, + discriminator_fn=None) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1bfdce9ee94d4d05d5186cd999361662bc0e3f85 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -0,0 +1,327 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFGAN's estimator.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator +from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses +from tensorflow.contrib.learn.python.learn.learn_io import graph_io +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training +from tensorflow.python.training import training_util + + +def generator_fn(noise_dict): + noise = noise_dict['x'] + return layers.fully_connected(noise, noise.shape[1].value) + + +def discriminator_fn(data, _): + return layers.fully_connected(data, 1) + + +def mock_head(testcase, expected_generator_inputs, expected_real_data, + generator_scope_name): + """Returns a mock head that validates logits values and variable names.""" + discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults + generator_var_names = set([ + '%s/fully_connected/weights:0' % generator_scope_name, + '%s/fully_connected/biases:0' % generator_scope_name]) + discriminator_var_names = set([ + '%s/fully_connected/weights:0' % discriminator_scope_name, + '%s/fully_connected/biases:0' % discriminator_scope_name]) + + def _create_estimator_spec(features, mode, logits, labels): + gan_model = logits # renaming for clarity + is_predict = mode == model_fn_lib.ModeKeys.PREDICT + testcase.assertIsNone(features) + testcase.assertIsNone(labels) + testcase.assertIsInstance(gan_model, namedtuples.GANModel) + + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + expected_var_names = (generator_var_names if is_predict else + generator_var_names | discriminator_var_names) + testcase.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + assertions = [] + def _or_none(x): + return None if is_predict else x + testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs) + # TODO(joelshor): Add check on `generated_data`. + testcase.assertItemsEqual( + generator_var_names, + set([x.name for x in gan_model.generator_variables])) + testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) + testcase.assertEqual(generator_fn, gan_model.generator_fn) + testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) + # TODO(joelshor): Add check on `discriminator_real_outputs`. + # TODO(joelshor): Add check on `discriminator_gen_outputs`. + if is_predict: + testcase.assertIsNone(gan_model.discriminator_scope) + else: + testcase.assertEqual(discriminator_scope_name, + gan_model.discriminator_scope.name) + testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn) + + with ops.control_dependencies(assertions): + if mode == model_fn_lib.ModeKeys.TRAIN: + return model_fn_lib.EstimatorSpec( + mode=mode, loss=array_ops.zeros([]), + train_op=control_flow_ops.no_op(), training_hooks=[]) + elif mode == model_fn_lib.ModeKeys.EVAL: + return model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data, + loss=array_ops.zeros([])) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + testcase.fail('Invalid mode: {}'.format(mode)) + + head = test.mock.NonCallableMagicMock(spec=head_lib._Head) + head.create_estimator_spec = test.mock.MagicMock( + wraps=_create_estimator_spec) + + return head + + +class GANModelFnTest(test.TestCase): + """Tests that _gan_model_fn passes expected logits to mock head.""" + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_logits_helper(self, mode): + """Tests that the expected logits are passed to mock head.""" + with ops.Graph().as_default(): + training_util.get_or_create_global_step() + generator_inputs = {'x': array_ops.zeros([5, 4])} + real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else + array_ops.zeros([5, 4])) + generator_scope_name = 'generator' + head = mock_head(self, + expected_generator_inputs=generator_inputs, + expected_real_data=real_data, + generator_scope_name=generator_scope_name) + estimator_spec = estimator._gan_model_fn( + features=generator_inputs, + labels=real_data, + mode=mode, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_scope_name=generator_scope_name, + head=head) + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=self._model_dir) as sess: + if mode == model_fn_lib.ModeKeys.TRAIN: + sess.run(estimator_spec.train_op) + elif mode == model_fn_lib.ModeKeys.EVAL: + sess.run(estimator_spec.loss) + elif mode == model_fn_lib.ModeKeys.PREDICT: + sess.run(estimator_spec.predictions) + else: + self.fail('Invalid mode: {}'.format(mode)) + + def test_logits_predict(self): + self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT) + + def test_logits_eval(self): + self._test_logits_helper(model_fn_lib.ModeKeys.EVAL) + + def test_logits_train(self): + self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN) + + +# TODO(joelshor): Add pandas test. +class GANEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, + lr_decay=False): + def make_opt(): + gstep = training_util.get_or_create_global_step() + lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) + return training.GradientDescentOptimizer(lr) + + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + est = estimator.GANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=gopt, + discriminator_optimizer=dopt, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([x for x in est.predict(predict_input_fn)]) + + self.assertAllEqual(prediction_size, predictions.shape) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + batch_size = 5 + data = np.zeros([batch_size, input_dim]) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[batch_size, input_dim]) + + def test_numpy_input_fn_lrdecay(self): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + batch_size = 5 + data = np.zeros([batch_size, input_dim]) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[batch_size, input_dim], + lr_decay=True) + + def test_input_fn_from_parse_example(self): + """Tests complete flow with input_fn constructed from parse_example.""" + input_dim = 4 + batch_size = 6 + data = np.zeros([batch_size, input_dim]) + + serialized_examples = [] + for datum in data: + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'x': feature_pb2.Feature( + float_list=feature_pb2.FloatList(value=datum)), + 'y': feature_pb2.Feature( + float_list=feature_pb2.FloatList(value=datum)), + })) + serialized_examples.append(example.SerializeToString()) + + feature_spec = { + 'x': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), + 'y': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), + } + def _train_input_fn(): + feature_map = parsing_ops.parse_example( + serialized_examples, feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + def _eval_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + def _predict_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + features.pop('y') + return features, None + + self._test_complete_flow( + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + prediction_size=[batch_size, input_dim]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/head.py b/tensorflow/contrib/gan/python/estimator/python/head.py new file mode 100644 index 0000000000000000000000000000000000000000..3225d6f41a1c17bfc8c57494dd683aaab45b10f3 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `GANEstimator`'s loss.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import head_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.head_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = head_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..204c646e194319c0e63599da0b2a4909ef270ef3 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -0,0 +1,206 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TFGAN-backed GAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import head +from tensorflow.python.framework import ops + +__all__ = [ + 'GANHead', + 'gan_head', +] + + +def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, + discriminator_optimizer, use_loss_summaries=True, + get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + name=None): + """Creates a `GANHead`. + + Args: + generator_loss_fn: A TFGAN loss function for the generator. Takes a + `GANModel` and returns a scalar. + discriminator_loss_fn: Same as `generator_loss_fn`, but for the + discriminator. + generator_optimizer: The optimizer for generator updates. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list + of hooks. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. + + Returns: + An instance of `GANHead`. + """ + return GANHead(generator_loss_fn=generator_loss_fn, + discriminator_loss_fn=discriminator_loss_fn, + generator_optimizer=generator_optimizer, + discriminator_optimizer=discriminator_optimizer, + use_loss_summaries=use_loss_summaries, + get_hooks_fn=get_hooks_fn, + name=name) + + +class GANHead(head._Head): # pylint: disable=protected-access + """`Head` for a GAN.""" + + def __init__(self, generator_loss_fn, discriminator_loss_fn, + generator_optimizer, discriminator_optimizer, + use_loss_summaries=True, + get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + name=None): + """`Head` for GAN training. + + Args: + generator_loss_fn: A TFGAN loss function for the generator. Takes a + `GANModel` and returns a scalar. + discriminator_loss_fn: Same as `generator_loss_fn`, but for the + discriminator. + generator_optimizer: The optimizer for generator updates. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list + of hooks. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. + """ + # TODO(joelshor): Validate inputs. + + if use_loss_summaries in [True, False]: + generator_loss_fn = functools.partial( + generator_loss_fn, add_summaries=use_loss_summaries) + discriminator_loss_fn = functools.partial( + discriminator_loss_fn, add_summaries=use_loss_summaries) + self._generator_loss_fn = generator_loss_fn + self._discriminator_loss_fn = discriminator_loss_fn + self._generator_optimizer = generator_optimizer + self._discriminator_optimizer = discriminator_optimizer + self._get_hooks_fn = get_hooks_fn + + @property + def name(self): + return self._name + + @property + def logits_dimension(self): + return None + + def create_loss(self, features, mode, logits, labels): + """Returns a GANLoss tuple from the provided GANModel. + + See `Head` for more details. + + Args: + features: Input `dict` of `Tensor` objects. Unused. + mode: Estimator's `ModeKeys`. + logits: A GANModel tuple. + labels: Must be `None`. + + Returns: + A GANLoss tuple. + + """ + _validate_logits_and_labels(logits, labels) + del mode, labels, features # unused for this head. + gan_model = logits # rename variable for clarity + return tfgan_tuples.GANLoss( + generator_loss=self._generator_loss_fn(gan_model), + discriminator_loss=self._discriminator_loss_fn(gan_model)) + + def create_estimator_spec( + self, features, mode, logits, labels=None, + train_op_fn=tfgan_train.gan_train_ops): + """Returns `EstimatorSpec` that a model_fn can return. + + See `Head` for more details. + + Args: + features: Must be `None`. + mode: Estimator's `ModeKeys`. + logits: A GANModel tuple. + labels: Must be `None`. + train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, + and discriminator optimizer, and returns a `GANTrainOps` tuple. For + example, this function can come from TFGAN's `train.py` library, or can + be custom. + + Returns: + `EstimatorSpec`. + + Raises: + ValueError: If `features` isn't `None`. + ValueError: If `train_op_fn` isn't provided in train mode. + """ + _validate_logits_and_labels(logits, labels) + if features is not None: + raise ValueError('`features` should be `None`. Instead, found: %s' % + features) + gan_model = logits # rename variable for clarity + with ops.name_scope('GANHead'): + if mode == model_fn_lib.ModeKeys.PREDICT: + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.PREDICT, + predictions=gan_model.generated_data) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_loss = self.create_loss( + features=None, mode=mode, logits=gan_model, labels=None) + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + # TODO(joelshor): Add metrics. If head name provided, append it to + # metric keys. + eval_metric_ops={}) + elif mode == model_fn_lib.ModeKeys.TRAIN: + if train_op_fn is None: + raise ValueError('train_op_fn can not be None.') + gan_loss = self.create_loss(None, mode, gan_model, None) + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer, + self._discriminator_optimizer) + training_hooks = self._get_hooks_fn(train_ops) + return model_fn_lib.EstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=train_ops.global_step_inc_op, + training_hooks=training_hooks) + else: + raise ValueError('Mode not recognized: %s' % mode) + + +def _validate_logits_and_labels(logits, labels): + if labels is not None: + raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must ' + 'be `None`. Instead, found: %s' % labels) + + if not isinstance(logits, tfgan_tuples.GANModel): + raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must ' + 'be an instnace of a `GANModel`. Instead, found: %s' % + logits) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8168f005cd1105886390a2384a936663c83fa5f5 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFGAN's head.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python.estimator.python import head + +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument + return math_ops.reduce_sum(gan_model.discriminator_real_outputs - + gan_model.discriminator_gen_outputs) + + +def get_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=None, + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +class GANHeadTest(test.TestCase): + + def setUp(self): + super(GANHeadTest, self).setUp() + self.gan_head = head.gan_head( + generator_loss_fn=dummy_loss, + discriminator_loss_fn=dummy_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + self.assertTrue(isinstance(self.gan_head, head.GANHead)) + + def _test_modes_helper(self, mode): + self.gan_head.create_estimator_spec( + features=None, + mode=mode, + logits=get_gan_model()) + + def test_modes_predict(self): + self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + + def test_modes_eval(self): + self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) + + def test_modes_train(self): + self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 4ef0d2d565edcfe998d6f2e3336eeb07520d21cc..6074694f8b87f65a2b2f8a3c4d7ac6b93482afd3 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -59,7 +59,7 @@ __all__ = [ INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v3_2017_09_13.tar.gz' INCEPTION_FROZEN_GRAPH = 'frozen_inception_v3.pb' -INCEPTION_V3_INPUT = 'inputs' +INCEPTION_V3_INPUT = 'input' INCEPTION_V3_OUTPUT = 'InceptionV3/Logits/SpatialSqueeze:0' INCEPTION_V3_FINAL_POOL = 'InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0' _INCEPTION_V3_NUM_CLASSES = 1001 @@ -317,13 +317,22 @@ def classifier_score(images, classifier_fn, num_batches=1): name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) logits.shape.assert_has_rank(2) + + # Use maximum precision for best results. + logits_dtype = logits.dtype + if logits_dtype != dtypes.float64: + logits = math_ops.cast(logits, dtypes.float64) + p = nn_ops.softmax(logits) q = math_ops.reduce_mean(p, axis=0) kl = _kl_divergence(p, logits, q) kl.shape.assert_has_rank(1) log_score = math_ops.reduce_mean(kl) + final_score = math_ops.exp(log_score) - return math_ops.exp(log_score) + if logits_dtype != dtypes.float64: + final_score = math_ops.cast(final_score, dtypes.float64) + return final_score inception_score = functools.partial( diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index cf33a9fe83a5e39ab62452a814e3d907abc1c284..30285964a53c388d4f9aaf65b6cabed362b3b012 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -68,7 +68,7 @@ def _expected_trace_sqrt_product(sigma, sigma_v): # A dummy GraphDef string with the minimum number of Ops. graphdef_string = """ node { - name: "inputs" + name: "input" op: "Placeholder" attr { key: "dtype" diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 940b5236276c3e06bf030e310f7453e93c7e3d32..508b4d20d8767f42246a0d0c87f911b7ac612f45 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -38,7 +38,7 @@ def _assert_is_image(data): data.shape[1:].assert_is_fully_defined() -def add_gan_model_image_summaries(gan_model, grid_size=10): +def add_gan_model_image_summaries(gan_model, grid_size=4): """Adds image summaries for real and fake images. Args: diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 29bd72d4db1920a9bb68d2ca292d81883c0ca67c..940762cf2aa0f473cd41d9d543e2773b565a5248 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -217,21 +217,25 @@ def acgan_discriminator_loss( Raises: TypeError: If the discriminator does not output a tuple. """ - loss_on_generated = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, - weights=generated_weights, scope=scope, loss_collection=None, - reduction=reduction) - loss_on_real = losses.softmax_cross_entropy( - one_hot_labels, discriminator_real_classification_logits, - weights=real_weights, label_smoothing=label_smoothing, scope=scope, - loss_collection=None, reduction=reduction) - loss = loss_on_generated + loss_on_real - util.add_loss(loss, loss_collection) + with ops.name_scope( + scope, 'acgan_discriminator_loss', + (discriminator_real_classification_logits, + discriminator_gen_classification_logits, one_hot_labels)) as scope: + loss_on_generated = losses.softmax_cross_entropy( + one_hot_labels, discriminator_gen_classification_logits, + weights=generated_weights, scope=scope, loss_collection=None, + reduction=reduction) + loss_on_real = losses.softmax_cross_entropy( + one_hot_labels, discriminator_real_classification_logits, + weights=real_weights, label_smoothing=label_smoothing, scope=scope, + loss_collection=None, reduction=reduction) + loss = loss_on_generated + loss_on_real + util.add_loss(loss, loss_collection) - if add_summaries: - summary.scalar('discriminator_gen_ac_loss', loss_on_generated) - summary.scalar('discriminator_real_ac_loss', loss_on_real) - summary.scalar('discriminator_ac_loss', loss) + if add_summaries: + summary.scalar('discriminator_gen_ac_loss', loss_on_generated) + summary.scalar('discriminator_real_ac_loss', loss_on_real) + summary.scalar('discriminator_ac_loss', loss) return loss @@ -275,12 +279,16 @@ def acgan_generator_loss( ValueError: if arg module not either `generator` or `discriminator` TypeError: if the discriminator does not output a tuple. """ - loss = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, weights=weights, - scope=scope, loss_collection=loss_collection, reduction=reduction) + with ops.name_scope( + scope, 'acgan_generator_loss', + (discriminator_gen_classification_logits, one_hot_labels)) as scope: + loss = losses.softmax_cross_entropy( + one_hot_labels, discriminator_gen_classification_logits, + weights=weights, scope=scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('generator_ac_loss', loss) + if add_summaries: + summary.scalar('generator_ac_loss', loss) return loss @@ -289,7 +297,6 @@ def acgan_generator_loss( # GANs` (https://arxiv.org/abs/1704.00028). -# TODO(joelshor): Figure out why this function can't be inside a name scope. def wasserstein_gradient_penalty( real_data, generated_data, @@ -331,46 +338,50 @@ def wasserstein_gradient_penalty( Raises: ValueError: If the rank of data Tensors is unknown. """ - if generated_data.shape.ndims is None: - raise ValueError('`generated_data` can\'t have unknown rank.') - if real_data.shape.ndims is None: - raise ValueError('`real_data` can\'t have unknown rank.') - - differences = generated_data - real_data - batch_size = differences.shape[0].value or array_ops.shape(differences)[0] - alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) - alpha = random_ops.random_uniform(shape=alpha_shape) - interpolates = real_data + (alpha * differences) - - # Reuse variables if a discriminator scope already exists. - reuse = False if discriminator_scope is None else True - with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', - reuse=reuse): - disc_interpolates = discriminator_fn(interpolates, generator_inputs) - - if isinstance(disc_interpolates, tuple): - # ACGAN case: disc outputs more than one tensor - disc_interpolates = disc_interpolates[0] - - gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] - gradient_squares = math_ops.reduce_sum( - math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) - # Propagate shape information, if possible. - if isinstance(batch_size, int): - gradient_squares.set_shape([ - batch_size] + gradient_squares.shape.as_list()[1:]) - # For numerical stability, add epsilon to the sum before taking the square - # root. Note tf.norm does not add epsilon. - slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes - 1.0) - penalty = losses.compute_weighted_loss( - penalties, weights, scope=scope, loss_collection=loss_collection, - reduction=reduction) + with ops.name_scope(scope, 'wasserstein_gradient_penalty', + (real_data, generated_data)) as scope: + real_data = ops.convert_to_tensor(real_data) + generated_data = ops.convert_to_tensor(generated_data) + if real_data.shape.ndims is None: + raise ValueError('`real_data` can\'t have unknown rank.') + if generated_data.shape.ndims is None: + raise ValueError('`generated_data` can\'t have unknown rank.') + + differences = generated_data - real_data + batch_size = differences.shape[0].value or array_ops.shape(differences)[0] + alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) + alpha = random_ops.random_uniform(shape=alpha_shape) + interpolates = real_data + (alpha * differences) + + with ops.name_scope(None): # Clear scope so update ops are added properly. + # Reuse variables if variables already exists. + with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', + reuse=variable_scope.AUTO_REUSE): + disc_interpolates = discriminator_fn(interpolates, generator_inputs) + + if isinstance(disc_interpolates, tuple): + # ACGAN case: disc outputs more than one tensor + disc_interpolates = disc_interpolates[0] + + gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] + gradient_squares = math_ops.reduce_sum( + math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) + # Propagate shape information, if possible. + if isinstance(batch_size, int): + gradient_squares.set_shape([ + batch_size] + gradient_squares.shape.as_list()[1:]) + # For numerical stability, add epsilon to the sum before taking the square + # root. Note tf.norm does not add epsilon. + slopes = math_ops.sqrt(gradient_squares + epsilon) + penalties = math_ops.square(slopes - 1.0) + penalty = losses.compute_weighted_loss( + penalties, weights, scope=scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('gradient_penalty_loss', penalty) + if add_summaries: + summary.scalar('gradient_penalty_loss', penalty) - return penalty + return penalty # Original losses from `Generative Adversarial Nets` @@ -544,7 +555,7 @@ def modified_generator_loss( discriminator_gen_outputs, label_smoothing=0.0, weights=1.0, - scope='generator_modified_loss', + scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): @@ -574,12 +585,15 @@ def modified_generator_loss( Returns: A loss Tensor. The shape depends on `reduction`. """ - loss = losses.sigmoid_cross_entropy( - array_ops.ones_like(discriminator_gen_outputs), discriminator_gen_outputs, - weights, label_smoothing, scope, loss_collection, reduction) + with ops.name_scope(scope, 'generator_modified_loss', + [discriminator_gen_outputs]) as scope: + loss = losses.sigmoid_cross_entropy( + array_ops.ones_like(discriminator_gen_outputs), + discriminator_gen_outputs, weights, label_smoothing, scope, + loss_collection, reduction) - if add_summaries: - summary.scalar('generator_modified_loss', loss) + if add_summaries: + summary.scalar('generator_modified_loss', loss) return loss @@ -737,7 +751,7 @@ def mutual_information_penalty( structured_generator_inputs, predicted_distributions, weights=1.0, - scope='generator_modified_loss', + scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): @@ -765,15 +779,16 @@ def mutual_information_penalty( _validate_information_penalty_inputs( structured_generator_inputs, predicted_distributions) - # Calculate the negative log-likelihood of the reconstructed noise. - log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in - zip(predicted_distributions, structured_generator_inputs)] - loss = -1 * losses.compute_weighted_loss( - log_probs, weights, scope, loss_collection=loss_collection, - reduction=reduction) + with ops.name_scope(scope, 'mutual_information_loss') as scope: + # Calculate the negative log-likelihood of the reconstructed noise. + log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in + zip(predicted_distributions, structured_generator_inputs)] + loss = -1 * losses.compute_weighted_loss( + log_probs, weights, scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('mutual_information_penalty', loss) + if add_summaries: + summary.scalar('mutual_information_penalty', loss) return loss diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 3e003dd0f808f80dcc486e78e8e101ac6f198947..b5cd8c92ba180e981e0faf877021cb6d69dc34b4 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -274,8 +274,8 @@ class ACGANLossTest(test.TestCase): self._discriminator_real_classification_logits, 'one_hot_labels': self._one_hot_labels, } - self._generator_loss_name = 'softmax_cross_entropy_loss/value' - self._discriminator_loss_name = 'add' + self._generator_loss_name = 'acgan_generator_loss/value' + self._discriminator_loss_name = 'acgan_discriminator_loss/add' self._expected_g_loss = 3.84974 self._expected_d_loss = 9.43950 @@ -453,10 +453,11 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): 'discriminator_scope': self._scope, } self._expected_loss = 9.00000 - self._expected_op_name = 'weighted_loss/value' + self._expected_op_name = 'wasserstein_gradient_penalty/value' self._batch_size = 1 def _discriminator_fn(self, inputs, _): + ops.add_to_collection('fake_update_ops', constant_op.constant(1.0)) return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs def test_loss_with_placeholder(self): @@ -487,6 +488,26 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): self.assertEqual( num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + def test_works_with_get_collection(self): + """Tests that gradient penalty works inside other scopes.""" + # We ran the discriminator once in the setup, so there should be an op + # already in the collection. + self.assertEqual(1, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + + # Make sure the op is added to the collection even if it's in a name scope. + with ops.name_scope('loss'): + tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) + self.assertEqual(2, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + + # Make sure the op is added to the collection even if it's in a variable + # scope. + with variable_scope.variable_scope('loss_vscope'): + tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) + self.assertEqual(3, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): """Tests for mutual_information_penalty.""" @@ -504,7 +525,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mul' + self._expected_op_name = 'mutual_information_loss/mul' self._batch_size = 2 diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index fca8063891fe53cb9a384fe6908eb6b1c61b90d7..b341f03a0ddaacca8b036189516c71908bee50eb 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -14,10 +14,41 @@ # ============================================================================== """TFGAN utilities for loss functions that accept GANModel namedtuples. -Example: +The losses and penalties in this file all correspond to losses in +`losses_impl.py`. Losses in that file take individual arguments, whereas in this +file they take a `GANModel` tuple. For example: + +losses_impl.py: + ```python + def wasserstein_discriminator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False) + ``` + +tuple_losses_impl.py: + ```python + def wasserstein_discriminator_loss( + gan_model, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False) + ``` + + + +Example usage: ```python - # `tfgan.losses.args` losses take individual arguments. - w_loss = tfgan.losses.args.wasserstein_discriminator_loss( + # `tfgan.losses.wargs` losses take individual arguments. + w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss( discriminator_real_outputs, discriminator_gen_outputs) diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index a99e3fbec8dc2a07030aa9356be2b05cfb689b8e..48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN.""" +"""Named tuples for TFGAN. + +TFGAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. At each step, you can either use a TFGAN +helper function in `train.py`, or you can manually construct a tuple. +""" from __future__ import absolute_import from __future__ import division @@ -115,7 +120,7 @@ class GANLoss( """GANLoss contains the generator and discriminator losses. Args: - generator_loss: A tensor for the generator loss.. + generator_loss: A tensor for the generator loss. discriminator_loss: A tensor for the discriminator loss. """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index cdc4d78e5b235bbbe53cf2717ac48a156fa96845..06dd281489be7b12d9123ca83d926bc7b81f7e10 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -14,7 +14,17 @@ # ============================================================================== """The TFGAN project provides a lightweight GAN training/testing framework. -See examples in `tensorflow_models` for details on how to use. +This file contains the core helper functions to create and train a GAN model. +See the README or examples in `tensorflow_models` for details on how to use. + +TFGAN training occurs in four steps: +1) Create a model +2) Add a loss +3) Create train ops +4) Run the train ops + +The functions in this file are organized around these four steps. Each function +corresponds to one of the steps. """ from __future__ import absolute_import @@ -51,16 +61,6 @@ __all__ = [ ] -def _convert_tensor_or_l_or_d(tensor_or_l_or_d): - """Convert input, list of inputs, or dictionary of inputs to Tensors.""" - if isinstance(tensor_or_l_or_d, (list, tuple)): - return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] - elif isinstance(tensor_or_l_or_d, dict): - return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} - else: - return ops.convert_to_tensor(tensor_or_l_or_d) - - def gan_model( # Lambdas defining models. generator_fn, @@ -133,20 +133,6 @@ def gan_model( discriminator_fn) -def _validate_distributions(distributions_l, noise_l): - if not isinstance(distributions_l, (tuple, list)): - raise ValueError('`predicted_distributions` must be a list. Instead, found ' - '%s.' % type(distributions_l)) - for dist in distributions_l: - if not isinstance(dist, ds.Distribution): - raise ValueError('Every element in `predicted_distributions` must be a ' - '`tf.Distribution`. Instead, found %s.' % type(dist)) - if len(distributions_l) != len(noise_l): - raise ValueError('Length of `predicted_distributions` %i must be the same ' - 'as the length of structured noise %i.' % - (len(distributions_l), len(noise_l))) - - def infogan_model( # Lambdas defining models. generator_fn, @@ -231,16 +217,6 @@ def infogan_model( predicted_distributions) -def _validate_acgan_discriminator_outputs(discriminator_output): - try: - a, b = discriminator_output - except (TypeError, ValueError): - raise TypeError( - 'A discriminator function for ACGAN must output a tuple ' - 'consisting of (discrimination logits, classification logits).') - return a, b - - def acgan_model( # Lambdas defining models. generator_fn, @@ -252,6 +228,7 @@ def acgan_model( # Optional scopes. generator_scope='Generator', discriminator_scope='Discriminator', + # Options. check_shapes=True): """Returns an ACGANModel contains all the pieces needed for ACGAN training. @@ -497,11 +474,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): def gan_train_ops( - model, # GANModel - loss, # GANLoss + model, + loss, generator_optimizer, discriminator_optimizer, - # Optional check flags. check_for_unused_update_ops=True, # Optional args to pass directly to the `create_train_op`. **kwargs): @@ -801,3 +777,40 @@ def get_sequential_train_steps( return gen_loss + dis_loss, should_stop return sequential_train_steps + + +# Helpers + + +def _convert_tensor_or_l_or_d(tensor_or_l_or_d): + """Convert input, list of inputs, or dictionary of inputs to Tensors.""" + if isinstance(tensor_or_l_or_d, (list, tuple)): + return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] + elif isinstance(tensor_or_l_or_d, dict): + return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} + else: + return ops.convert_to_tensor(tensor_or_l_or_d) + + +def _validate_distributions(distributions_l, noise_l): + if not isinstance(distributions_l, (tuple, list)): + raise ValueError('`predicted_distributions` must be a list. Instead, found ' + '%s.' % type(distributions_l)) + for dist in distributions_l: + if not isinstance(dist, ds.Distribution): + raise ValueError('Every element in `predicted_distributions` must be a ' + '`tf.Distribution`. Instead, found %s.' % type(dist)) + if len(distributions_l) != len(noise_l): + raise ValueError('Length of `predicted_distributions` %i must be the same ' + 'as the length of structured noise %i.' % + (len(distributions_l), len(noise_l))) + + +def _validate_acgan_discriminator_outputs(discriminator_output): + try: + a, b = discriminator_output + except (TypeError, ValueError): + raise TypeError( + 'A discriminator function for ACGAN must output a tuple ' + 'consisting of (discrimination logits, classification logits).') + return a, b diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 42968ae63b769f7cea7385933fbadb0782cc86f3..7ffdbb7139281734917fdb715601b317eb58b82f 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -397,27 +397,57 @@ def swap_inputs(sgv0, sgv1): def reroute_inputs(sgv0, sgv1): - """Re-route all the inputs of sgv0 to sgv1 (see reroute_inputs).""" + """Re-route all the inputs of two subgraphs. + + Args: + sgv0: the first subgraph to have its inputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + sgv1: the second subgraph to have its inputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + Returns: + A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. + Note that the function argument sgv0 and sgv1 are also modified in place. + Raises: + StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using + the same rules than the function subgraph.make_view. + """ return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b) def swap_outputs(sgv0, sgv1): - """Swap all the outputs of sgv0 and sgv1 (see _reroute_outputs).""" + """Swap all the outputs of sgv0 and sgv1 (see reroute_outputs).""" return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap) def reroute_outputs(sgv0, sgv1): - """Re-route all the outputs of sgv0 to sgv1 (see _reroute_outputs).""" + """Re-route all the outputs of two operations. + + Args: + sgv0: the first subgraph to have its outputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + sgv1: the second subgraph to have its outputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + Returns: + A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. + Note that the function argument sgv0 and sgv1 are also modified in place. + Raises: + StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using + the same rules than the function subgraph.make_view. + """ return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b) def swap_ios(sgv0, sgv1): - """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute).""" + """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute_sgv).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap) def reroute_ios(sgv0, sgv1): - """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute).""" + """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b) diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ab5776b9dd66bb082e9ca3922e8902bfebe6b0b8..ca00394388f67e2ed9508684a47b23c3ee9e79e8 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -191,14 +191,14 @@ class TransformTest(test.TestCase): # Extract the operations. replacement_ts = {w.value(): g} original_mul1_grad = (ops.get_default_graph(). - get_operation_by_name("grad/mul1_grad/mul_1")) + get_operation_by_name("grad/mul1_grad/Mul_1")) # Should not raise exception. res = ge.graph_replace(g, replacement_ts, dst_scope="res") # Extract the operations after graph_replace. result_mul1_grad = (ops.get_default_graph(). - get_operation_by_name("res/grad/mul1_grad/mul_1")) + get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. self.assertEquals(original_mul1_grad._original_op.name, u"mul1") diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD index d81e868d4a922698e4755733b999112088fa2a0b..1576c9ec9b3e058091fd7db865c0368b53d9d3cb 100644 --- a/tensorflow/contrib/hooks/BUILD +++ b/tensorflow/contrib/hooks/BUILD @@ -19,26 +19,6 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", - ], -) - -py_test( - name = "profiler_hook_test", - size = "small", - srcs = ["python/training/profiler_hook_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":hooks", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", "//tensorflow/python:training", ], ) diff --git a/tensorflow/contrib/hooks/python/training/profiler_hook.py b/tensorflow/contrib/hooks/python/training/profiler_hook.py index 35aa25edfde6f2ed7051ed75ff4f53f8732ae76e..6173aa0797138730e79b21bc9a1779d346edab6b 100644 --- a/tensorflow/contrib/hooks/python/training/profiler_hook.py +++ b/tensorflow/contrib/hooks/python/training/profiler_hook.py @@ -12,93 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Additional `SessionRunHook` implementations to complement those in -tensorflow/python/training. - -""" +"""Placeholder of ProfilerHook for backward compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path - -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import timeline -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer -from tensorflow.python.training.session_run_hook import SessionRunArgs -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training_util - - -class ProfilerHook(session_run_hook.SessionRunHook): - """Captures CPU/GPU profiling information every N steps or seconds. - - This produces files called "timeline-.json", which are in Chrome - Trace format. - - For more information see: - https://github.com/catapult-project/catapult/blob/master/tracing/README.md""" - - def __init__(self, - save_steps=None, - save_secs=None, - output_dir="", - show_dataflow=True, - show_memory=False): - """Initializes a hook that takes periodic profiling snapshots. - - Args: - save_steps: `int`, save profile traces every N steps. Exactly one of - `save_secs` and `save_steps` should be set. - save_secs: `int`, save profile traces every N seconds. - output_dir: `string`, the directory to save the profile traces to. - Defaults to the current directory. - show_dataflow: `bool`, if True, add flow events to the trace connecting - producers and consumers of tensors. - show_memory: `bool`, if True, add object snapshot events to the trace - showing the sizes and lifetimes of tensors. - """ - self._output_file = os.path.join(output_dir, "timeline-{}.json") - self._show_dataflow = show_dataflow - self._show_memory = show_memory - self._timer = SecondOrStepTimer(every_secs=save_secs, - every_steps=save_steps) - - def begin(self): - self._next_step = None - self._global_step_tensor = training_util.get_global_step() - if self._global_step_tensor is None: - raise RuntimeError( - "Global step should be created to use ProfilerHook.") - - def before_run(self, run_context): - self._request_summary = ( - self._next_step is None or - self._timer.should_trigger_for_step(self._next_step)) - requests = {"global_step": self._global_step_tensor} - opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) - if self._request_summary else None) - - return SessionRunArgs(requests, options=opts) - - def after_run(self, run_context, run_values): - global_step = run_values.results["global_step"] - - if self._request_summary: - self._timer.update_last_triggered_step(global_step) - self._save(global_step, - self._output_file.format(global_step), - run_values.run_metadata.step_stats) - - self._next_step = global_step + 1 +from tensorflow.python.training import basic_session_run_hooks - def _save(self, step, save_path, step_stats): - logging.info("Saving timeline for %d into '%s'.", step, save_path) - with gfile.Open(save_path, "w") as f: - trace = timeline.Timeline(step_stats) - f.write(trace.generate_chrome_trace_format( - show_dataflow=self._show_dataflow, - show_memory=self._show_memory)) +ProfilerHook = basic_session_run_hooks.ProfilerHook # pylint: disable=invalid-name diff --git a/tensorflow/contrib/hooks/python/training/profiler_hook_test.py b/tensorflow/contrib/hooks/python/training/profiler_hook_test.py deleted file mode 100644 index e7ecb5eb2fcc56f14f3d5babe2c22652159afd76..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/hooks/python/training/profiler_hook_test.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for profiler_hook.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os.path -import shutil -import tempfile - -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.hooks.python.training import ProfilerHook -from tensorflow.python.framework import ops -from tensorflow.python.ops import state_ops -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.training import monitored_session - - -class ProfilerHookTest(test.TestCase): - - def setUp(self): - super(ProfilerHookTest, self).setUp() - self.output_dir = tempfile.mkdtemp() - self.graph = ops.Graph() - self.filepattern = os.path.join(self.output_dir, "timeline-*.json") - with self.graph.as_default(): - self.global_step = variables.get_or_create_global_step() - self.train_op = state_ops.assign_add(self.global_step, 1) - - def tearDown(self): - super(ProfilerHookTest, self).tearDown() - shutil.rmtree(self.output_dir, ignore_errors=True) - - def _count_timeline_files(self): - return len(gfile.Glob(self.filepattern)) - - def test_raise_in_both_secs_and_steps(self): - with self.assertRaises(ValueError): - ProfilerHook(save_secs=10, save_steps=20) - - def test_raise_in_none_secs_and_steps(self): - with self.assertRaises(ValueError): - ProfilerHook(save_secs=None, save_steps=None) - - def test_save_secs_saves_in_first_step(self): - with self.graph.as_default(): - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) - self.assertEqual(1, self._count_timeline_files()) - - @test.mock.patch('time.time') - def test_save_secs_saves_periodically(self, mock_time): - # Pick a fixed start time. - current_time = 1484863632.320497 - - with self.graph.as_default(): - mock_time.return_value = current_time - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - # Simulate 2.5 seconds of sleep. - mock_time.return_value = current_time + 2.5 - sess.run(self.train_op) # Saved. - - # Pretend some small amount of time has passed. - mock_time.return_value = current_time + 0.1 - sess.run(self.train_op) # Not saved. - # Edge test just before we should save the timeline. - mock_time.return_value = current_time + 1.9 - sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) - - mock_time.return_value = current_time + 4.5 - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - - def test_save_steps_saves_in_first_step(self): - with self.graph.as_default(): - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - - def test_save_steps_saves_periodically(self): - with self.graph.as_default(): - hook = ProfilerHook(save_steps=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - self.assertEqual(0, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index a18f14112e469b1cf83a046fa65b87e5c69fb88b..d0600d4668b2943e7fc880a079750d3a59406d68 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -211,6 +211,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":image_py", + ":single_image_random_dot_stereograms_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index 59a322d3ca6e7e53872f8e7e126e30923ddd77a0..d030dffadeb9d67f7ffcbc197a2a3feb9b3b122d 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -26,6 +26,8 @@ projective transforms (including rotation) are supported. @@random_yiq_hsv @@rotate @@transform +@@translate +@@translations_to_projective_transforms @@bipartite_match @@single_image_random_dot_stereograms """ @@ -41,6 +43,8 @@ from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_t from tensorflow.contrib.image.python.ops.image_ops import compose_transforms from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform +from tensorflow.contrib.image.python.ops.image_ops import translate +from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index b8a0706b61449ebebeb2f1dc98b438f9dd620aa3..b50177ae5651fbc15f292e11031411c2074357ec 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -36,8 +36,8 @@ _DTYPES = set( class ImageOpsTest(test_util.TensorFlowTestCase): def test_zeros(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): for shape in [(5, 5), (24, 24), (2, 24, 24, 3)]: for angle in [0, 1, np.pi / 2.0]: image = array_ops.zeros(shape, dtype) @@ -46,8 +46,8 @@ class ImageOpsTest(test_util.TensorFlowTestCase): np.zeros(shape, dtype.as_numpy_dtype())) def test_rotate_even(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = array_ops.reshape( math_ops.cast(math_ops.range(36), dtype), (6, 6)) image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1]) @@ -68,8 +68,8 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [1, 7, 13, 19, 25, 31], [0, 6, 12, 18, 24, 30]]]) def test_rotate_odd(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = array_ops.reshape( math_ops.cast(math_ops.range(25), dtype), (5, 5)) image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1]) @@ -87,9 +87,25 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [22, 17, 12, 7, 2], [23, 18, 13, 8, 3], [24, 19, 14, 9, 4]]]) + def test_translate(self): + for dtype in _DTYPES: + with self.test_session(): + image = constant_op.constant( + [[1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1]], dtype=dtype) + translation = constant_op.constant([-1, -1], dtypes.float32) + image_translated = image_ops.translate(image, translation) + self.assertAllEqual(image_translated.eval(), + [[1, 0, 1, 0], + [0, 1, 0, 0], + [1, 0, 1, 0], + [0, 0, 0, 0]]) + def test_compose(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = constant_op.constant( [[1, 1, 1, 0], [1, 0, 0, 0], @@ -246,4 +262,3 @@ class BipartiteMatchTest(test_util.TensorFlowTestCase): if __name__ == "__main__": googletest.main() - diff --git a/tensorflow/contrib/image/python/ops/distort_image_ops.py b/tensorflow/contrib/image/python/ops/distort_image_ops.py index 39f023a2b40a1a8481217fe8fa191a5072e7a3ff..06e8e4ee720d04f4b29a25f833297bb17a7d239c 100644 --- a/tensorflow/contrib/image/python/ops/distort_image_ops.py +++ b/tensorflow/contrib/image/python/ops/distort_image_ops.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.ops import gen_distort_image_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -132,7 +133,7 @@ def adjust_hsv_in_yiq(image, orig_dtype = image.dtype flt_image = image_ops.convert_image_dtype(image, dtypes.float32) - rgb_altered = _distort_image_ops.adjust_hsv_in_yiq( + rgb_altered = gen_distort_image_ops.adjust_hsv_in_yiq( flt_image, delta_hue, scale_saturation, scale_value) return image_ops.convert_image_dtype(rgb_altered, orig_dtype) diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index aef3e385b57486d5cb3cb13d9e8b9519768abd7c..011ddeaa9a1eebaa507c9e0d33f9546ff3497166 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -37,16 +37,18 @@ _IMAGE_DTYPES = set( ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) -def rotate(images, angles, interpolation="NEAREST"): +def rotate(images, angles, interpolation="NEAREST", name=None): """Rotate image(s) by the passed angle(s) in radians. Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. angles: A scalar angle to rotate all images by, or (if images has rank 4) a vector of length num_images, with an angle for each image in the batch. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + name: The name of the op. Returns: Image(s) with the same type and shape as `images`, rotated by the given @@ -55,38 +57,77 @@ def rotate(images, angles, interpolation="NEAREST"): Raises: TypeError: If `image` is an invalid type. """ - image_or_images = ops.convert_to_tensor(images, name="images") - if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") - - image_height = math_ops.cast(array_ops.shape(images)[1], dtypes.float32)[None] - image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None] - output = transform( - images, - angles_to_projective_transforms(angles, image_height, image_width), - interpolation=interpolation) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + with ops.name_scope(name, "rotate"): + image_or_images = ops.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + elif image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :, None] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images[None, :, :, :] + elif len(image_or_images.get_shape()) == 4: + images = image_or_images + else: + raise TypeError("Images should have rank between 2 and 4.") + + image_height = math_ops.cast(array_ops.shape(images)[1], + dtypes.float32)[None] + image_width = math_ops.cast(array_ops.shape(images)[2], + dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation) + if image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + return output[0, :, :, 0] + elif len(image_or_images.get_shape()) == 3: + return output[0, :, :, :] + else: + return output + + +def translate(images, translations, interpolation="NEAREST", name=None): + """Translate image(s) by the passed vectors(s). + Args: + images: A tensor of shape (num_images, num_rows, num_columns, num_channels) + (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. + translations: A vector representing [dx, dy] or (if images has rank 4) + a matrix of length num_images, with a [dx, dy] vector for each image in + the batch. + interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + name: The name of the op. -def angles_to_projective_transforms(angles, image_height, image_width): + Returns: + Image(s) with the same type and shape as `images`, translated by the given + vector(s). Empty space due to the translation will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + """ + with ops.name_scope(name, "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation) + + +def angles_to_projective_transforms(angles, + image_height, + image_width, + name=None): """Returns projective transform(s) for the given angle(s). Args: angles: A scalar angle to rotate all images by, or (for batches of images) - a vector with an angle to rotate each image in the batch. + a vector with an angle to rotate each image in the batch. The rank must + be statically known (the shape is not `TensorShape(None)`. image_height: Height of the image(s) to be transformed. image_width: Width of the image(s) to be transformed. @@ -94,41 +135,89 @@ def angles_to_projective_transforms(angles, image_height, image_width): A tensor of shape (num_images, 8). Projective transforms which can be given to `tf.contrib.image.transform`. """ - angle_or_angles = ops.convert_to_tensor( - angles, name="angles", dtype=dtypes.float32) - if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test - angles = angle_or_angles[None] - elif len(angle_or_angles.get_shape()) == 1: - angles = angle_or_angles - else: - raise TypeError("Angles should have rank 0 or 1.") - x_offset = ((image_width - 1) - (math_ops.cos(angles) * - (image_width - 1) - math_ops.sin(angles) * - (image_height - 1))) / 2.0 - y_offset = ((image_height - 1) - (math_ops.sin(angles) * - (image_width - 1) + math_ops.cos(angles) * - (image_height - 1))) / 2.0 - num_angles = array_ops.shape(angles)[0] - return array_ops.concat( - values=[ - math_ops.cos(angles)[:, None], - -math_ops.sin(angles)[:, None], - x_offset[:, None], - math_ops.sin(angles)[:, None], - math_ops.cos(angles)[:, None], - y_offset[:, None], - array_ops.zeros((num_angles, 2), dtypes.float32), - ], - axis=1) - - -def transform(images, transforms, interpolation="NEAREST"): + with ops.name_scope(name, "angles_to_projective_transforms"): + angle_or_angles = ops.convert_to_tensor( + angles, name="angles", dtype=dtypes.float32) + if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test + angles = angle_or_angles[None] + elif len(angle_or_angles.get_shape()) == 1: + angles = angle_or_angles + else: + raise TypeError("Angles should have rank 0 or 1.") + x_offset = ((image_width - 1) - (math_ops.cos(angles) * + (image_width - 1) - math_ops.sin(angles) * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - (math_ops.sin(angles) * + (image_width - 1) + math_ops.cos(angles) * + (image_height - 1))) / 2.0 + num_angles = array_ops.shape(angles)[0] + return array_ops.concat( + values=[ + math_ops.cos(angles)[:, None], + -math_ops.sin(angles)[:, None], + x_offset[:, None], + math_ops.sin(angles)[:, None], + math_ops.cos(angles)[:, None], + y_offset[:, None], + array_ops.zeros((num_angles, 2), dtypes.float32), + ], + axis=1) + + +def translations_to_projective_transforms(translations, name=None): + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing [dx, dy] or a matrix of + 2-element lists representing [dx, dy] to translate for each image + (for a batch of images). The rank must be statically known (the shape + is not `TensorShape(None)`. + name: The name of the op. + + Returns: + A tensor of shape (num_images, 8) projective transforms which can be given + to `tf.contrib.image.transform`. + """ + with ops.name_scope(name, "translations_to_projective_transforms"): + translation_or_translations = ops.convert_to_tensor( + translations, name="translations", dtype=dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + elif len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + elif len(translation_or_translations.get_shape()) == 2: + translations = translation_or_translations + else: + raise TypeError("Translations should have rank 1 or 2.") + num_translations = array_ops.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return array_ops.concat( + values=[ + array_ops.ones((num_translations, 1), dtypes.float32), + array_ops.zeros((num_translations, 1), dtypes.float32), + -translations[:, 0, None], + array_ops.zeros((num_translations, 1), dtypes.float32), + array_ops.ones((num_translations, 1), dtypes.float32), + -translations[:, 1, None], + array_ops.zeros((num_translations, 2), dtypes.float32), + ], + axis=1) + + +def transform(images, transforms, interpolation="NEAREST", name=None): """Applies the given transform(s) to the image(s). Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. transforms: Projective transform matrix/matrices. A vector of length 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point @@ -146,34 +235,40 @@ def transform(images, transforms, interpolation="NEAREST"): Raises: TypeError: If `image` is an invalid type. """ - image_or_images = ops.convert_to_tensor(images, name="images") - transform_or_transforms = ops.convert_to_tensor( - transforms, name="transforms", dtype=dtypes.float32) - if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") - - if len(transform_or_transforms.get_shape()) == 1: - transforms = transform_or_transforms[None] - elif len(transform_or_transforms.get_shape()) == 2: - transforms = transform_or_transforms - else: - raise TypeError("Transforms should have rank 1 or 2.") - output = gen_image_ops.image_projective_transform( - images, transforms, interpolation=interpolation.upper()) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + with ops.name_scope(name, "transform"): + image_or_images = ops.convert_to_tensor(images, name="images") + transform_or_transforms = ops.convert_to_tensor( + transforms, name="transforms", dtype=dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + elif image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :, None] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images[None, :, :, :] + elif len(image_or_images.get_shape()) == 4: + images = image_or_images + else: + raise TypeError("Images should have rank between 2 and 4.") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise TypeError( + "transform_or_transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + raise TypeError("Transforms should have rank 1 or 2.") + output = gen_image_ops.image_projective_transform( + images, transforms, interpolation=interpolation.upper()) + if len(image_or_images.get_shape()) == 2: + return output[0, :, :, 0] + elif len(image_or_images.get_shape()) == 3: + return output[0, :, :, :] + else: + return output def compose_transforms(*transforms): @@ -191,11 +286,12 @@ def compose_transforms(*transforms): order. """ assert transforms, "transforms cannot be empty" - composed = _flat_transforms_to_matrices(transforms[0]) - for tr in transforms[1:]: - # Multiply batches of matrices. - composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) - return _transform_matrices_to_flat(composed) + with ops.name_scope("compose_transforms"): + composed = _flat_transforms_to_matrices(transforms[0]) + for tr in transforms[1:]: + # Multiply batches of matrices. + composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) + return _transform_matrices_to_flat(composed) def _flat_transforms_to_matrices(transforms): @@ -211,8 +307,8 @@ def _flat_transforms_to_matrices(transforms): def _transform_matrices_to_flat(transform_matrices): # Flatten each matrix. - transforms = array_ops.reshape( - transform_matrices, constant_op.constant([-1, 9])) + transforms = array_ops.reshape(transform_matrices, + constant_op.constant([-1, 9])) # Divide each matrix by the last entry (normally 1). transforms /= transforms[:, 8:9] return transforms[:, :8] @@ -260,10 +356,10 @@ def _image_projective_transform_grad(op, grad): return [output, None] -def bipartite_match( - distance_mat, - num_valid_rows, - top_k=-1): +def bipartite_match(distance_mat, + num_valid_rows, + top_k=-1, + name="bipartite_match"): """Find bipartite matching based on a given distance matrix. A greedy bi-partite matching algorithm is used to obtain the matching with @@ -282,6 +378,7 @@ def bipartite_match( top_k: A scalar that specifies the number of top-k matches to retrieve. If set to be negative, then is set according to the maximum number of matches from `distance_mat`. + name: The name of the op. Returns: row_to_col_match_indices: A vector of length num_rows, which is the number @@ -292,7 +389,8 @@ def bipartite_match( If `col_to_row_match_indices[j]` is not -1, column j is matched to row `col_to_row_match_indices[j]`. """ - result = gen_image_ops.bipartite_match(distance_mat, num_valid_rows, top_k) + result = gen_image_ops.bipartite_match( + distance_mat, num_valid_rows, top_k, name=name) return result diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 79261c5e7501566537ee9492b5aa64570599e862..5cccf26028ca6bf269dbc67a33075351edecb407 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.ops import gen_single_image_random_dot_stereograms_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader @@ -107,7 +108,7 @@ def single_image_random_dot_stereograms( 'depth_values' """ - result = _sirds_ops.single_image_random_dot_stereograms( + result = gen_single_image_random_dot_stereograms_ops.single_image_random_dot_stereograms( # pylint: disable=line-too-long depth_values=depth_values, hidden_surface_removal=hidden_surface_removal, convergence_dots_size=convergence_dots_size, diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 4d00b8536ee1e97921394343c29c0fdb4694f43c..762a2f0b57e95e2fef3dd177070701afb410e93a 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -7,6 +7,78 @@ faster in `>14x` fewer iterations than SGD with Momentum. [kfac-paper]: https://arxiv.org/abs/1503.05671 +## What is K-FAC? + +K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation +to the [Natural Gradient][natural_gradient] algorithm designed specifically for +neural networks. It maintains a block-diagonal approximation to the [Fisher +Information matrix][fisher_information], whose inverse preconditions the +gradient. + +K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations. +Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD. + +Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What +are the weights for layer i?"). As such, you must add some additional code while +constructing your model to use K-FAC. + +[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746 +[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form + +## Why should I use K-FAC? + +K-FAC can take advantage of the curvature of the optimization problem, resulting +in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same +loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how +training loss changes as a function of number of epochs, steps, and seconds: + +![autoencoder](g3doc/autoencoder.png) + +## Is K-FAC for me? + +If you have a feedforward or convolutional model for classification that is +converging too slowly, K-FAC is for you. K-FAC can be used in your model if: + +* Your model defines a posterior distribution. +* Your model uses only fully-connected or convolutional layers (residual + connections OK). +* You are training on CPU or GPU. +* You can modify model code to register layers with K-FAC. + +## How do I use K-FAC? + +Using K-FAC requires three steps: + +1. Registering layer inputs, weights, and pre-activations with a + `LayerCollection`. +1. Minimizing the loss with a `KfacOptimizer`. +1. Keeping K-FAC's preconditioner updated. + +```python +# Build model. +w = tf.get_variable("w", ...) +b = tf.get_variable("b", ...) +logits = tf.matmul(x, w) + b +loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) + +# Register layers. +layer_collection = LayerCollection() +layer_collection.register_fully_connected((w, b), x, logits) +layer_collection.register_categorical_predictive_distribution(logits) + +# Construct training ops. +optimizer = KfacOptimizer(..., layer_collection=layer_collection) +train_op = optimizer.minimize(loss) + +# Minimize loss. +with tf.Session() as sess: + ... + sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op]) +``` + +See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations. + ## Authors - Alok Aggarwal diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png new file mode 100644 index 0000000000000000000000000000000000000000..20f93c77034f3355653a6a260cccdad29c080eaf Binary files /dev/null and b/tensorflow/contrib/kfac/g3doc/autoencoder.png differ diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py index a3b95c9b376da2de8eed2d4c08dbc635f3aeb9b6..21b5cde9b931a95110c9a5fd7930a3a4ee74b207 100644 --- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py @@ -36,17 +36,17 @@ class CurvatureMatrixVectorProductComputer(object): For example, the Fisher associated with a log-prob loss w.r.t. the parameters. - The vecs argument to each method are lists of tensors that must be the + The 'vecs' argument to each method are lists of tensors that must be the size as the corresponding ones from "wrt_tensors". They represent the vector being multiplied. "factors" of the matrix M are defined as matrices B such that B*B^T = M. - Methods that multiply by the factor B take a "loss_inner_vecs" argument - instead of vecs, which must be a list of tensors with shapes given by the + Methods that multiply by the factor B take a 'loss_inner_vecs' argument + instead of 'vecs', which must be a list of tensors with shapes given by the corresponding XXX_inner_shapes property. Note that matrix-vector products are not normalized by the batch size, nor - are any damping terms added to the results. These things can easily be + are any damping terms added to the results. These things can be easily applied externally, if desired. See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf @@ -61,7 +61,8 @@ class CurvatureMatrixVectorProductComputer(object): Args: losses: A list of LossFunction instances whose sum defines the total loss. wrt_tensors: A list of Tensors to compute the differential quantities - defining the matrices with respect to (see class description). + (defining the matrices) with respect to. See class description for more + info. """ self._losses = losses self._inputs_to_losses = list(loss.inputs for loss in losses) @@ -73,24 +74,23 @@ class CurvatureMatrixVectorProductComputer(object): return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) # Jacobian multiplication functions: - # NOTE: These implementations use tf.gradients and thus aren't actually - # computing partial derivatives, but total derivatives instead (despite what - # the documentation for tf.gradients says). Because we require partial - # derivatives for Jacobians this implementation will only be correct if the - # partial derivatives are equal to the full derivatives. This happens as long - # as the elements of wrt_tensors don't depend on each other in the graph. If - # these tensors are standard neural network parameters this will be true. def _multiply_jacobian(self, vecs): """Multiply vecs by the Jacobian of losses.""" + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). jacobian_vecs_flat = utils.fwd_gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs) + self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, + stop_gradients=self._wrt_tensors) return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) def _multiply_jacobian_transpose(self, loss_vecs): """Multiply vecs by the transpose Jacobian of losses.""" loss_vecs_flat = nest.flatten(loss_vecs) + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). return gradients_impl.gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat) + self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, + stop_gradients=self._wrt_tensors) # Losses Fisher/Hessian multiplication functions: def _multiply_loss_fisher(self, loss_vecs): diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 3bae45b32402c3ea60f3a82b99580d90dc150f86..9d8bb8c8ceac691f5ce20938dea696a4b6dbcd42 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -34,6 +34,14 @@ from tensorflow.python.ops import math_ops NORMALIZE_DAMPING_POWER = 1.0 +def set_global_constants(normalize_damping_power=None): + """Sets various global constants used by the classes in this module.""" + global NORMALIZE_DAMPING_POWER + + if normalize_damping_power is not None: + NORMALIZE_DAMPING_POWER = normalize_damping_power + + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): """Abstract base class for objects modeling approximate Fisher matrix blocks. diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index c6cc169b3784ca2e60cde6cd703f13ddeaaad985..59389f8d385c18f50914d690cfaa2825ef807ed3 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -31,7 +31,8 @@ _allowed_symbols = [ 'KroneckerProductFB', 'FullyConnectedKFACBasicFB', 'ConvKFCBasicFB', - 'ConvDiagonalFB' + 'ConvDiagonalFB', + 'set_global_constants', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 3d14cf1ead32376b2cc79a1269110bd80d253e81..d3c783ee2f343ce3afbdcf34b6f237e21929a3a0 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages - # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). INIT_COVARIANCES_AT_ZERO = False @@ -51,6 +50,25 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 EIGENVALUE_CLIPPING_THRESHOLD = 0.0 +def set_global_constants(init_covariances_at_zero=None, zero_debias=None, + eigenvalue_decomposition_threshold=None, + eigenvalue_clipping_threshold=None): + """Sets various global constants used by the classes in this module.""" + global INIT_COVARIANCES_AT_ZERO + global ZERO_DEBIAS + global EIGENVALUE_DECOMPOSITION_THRESHOLD + global EIGENVALUE_CLIPPING_THRESHOLD + + if init_covariances_at_zero is not None: + INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero + if zero_debias is not None: + ZERO_DEBIAS = zero_debias + if eigenvalue_decomposition_threshold is not None: + EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold + if eigenvalue_clipping_threshold is not None: + EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + + def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument return array_ops.diag(array_ops.ones(shape[0], dtype)) @@ -298,7 +316,7 @@ class InverseProvidingFactor(FisherFactor): self.register_eigendecomp() # ensures self._eigendecomp is set eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence - # the matrix self._cov is positive semidefinite by construction, but the + # The matrix self._cov is positive semidefinite by construction, but the # numerical eigenvalues could be negative due to numerical errors, so here # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD. clipped_eigenvalues = math_ops.maximum(eigenvalues, @@ -421,8 +439,8 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the - # inputs don't change with the 'idx' argument to _compute_new_cov. Only - # the target entry of _outputs_grads changes with idx. + # inputs don't change with the 'idx' argument to _compute_new_cov. (Only + # the target entry of _outputs_grads changes with idx.) if has_bias: inputs = _append_homog(inputs) self._squared_inputs = math_ops.square(inputs) @@ -484,8 +502,8 @@ class ConvDiagonalFactor(DiagonalFactor): + tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the - # inputs don't change with the 'idx' argument to _compute_new_cov. Only - # the target entry of _outputs_grads changes with idx. + # inputs don't change with the 'idx' argument to _compute_new_cov. (Only + # the target entry of _outputs_grads changes with idx.) filter_height, filter_width, _, _ = self._filter_shape patches = array_ops.extract_image_patches( inputs, @@ -526,9 +544,8 @@ class ConvDiagonalFactor(DiagonalFactor): def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the per-training-case "gradients". - # It does this simply by computing a giant tensor containing all of these - # them, doing an entry-wise square, and them summing along the batch - # dimension. + # It does this simply by computing a giant tensor containing all of these, + # doing an entry-wise square, and them summing along the batch dimension. case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches, outputs_grad) return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index 49a07b15986b946105d32a1950bcccabaa363cef..23ee93cd405bbf719939df89d525c812ee061f8b 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -40,6 +40,7 @@ _allowed_symbols = [ "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", "ConvDiagonalFactor", + "set_global_constants", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index d80382b9cf31d784d7d2267a18cf88362fea95fc..979a4fd1de8f612a440f41f5ba0275c12bb3fce0 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -104,7 +104,7 @@ class LossFunction(object): @abc.abstractmethod def multiply_hessian_factor_transpose(self, vector): - """Right-multiply a vector by the tranpose of a factor B of the Hessian. + """Right-multiply a vector by the transpose of a factor B of the Hessian. Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) of the loss function with respect to its inputs. Typically this will be @@ -218,7 +218,7 @@ class NegativeLogProbLoss(LossFunction): @abc.abstractmethod def multiply_fisher_factor_transpose(self, vector): - """Right-multiply a vector by the tranpose of a factor B of the Fisher. + """Right-multiply a vector by the transpose of a factor B of the Fisher. Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- product of gradients) with respect to the parameters of the underlying @@ -397,7 +397,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): This class parameterizes a multivariate normal distribution with n independent dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not - assume the variance is held constant. The Fisher Information for for n = 1 + assume the variance is held constant. The Fisher Information for n = 1 is given by, F = [[1 / variance, 0], diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py index 0617c5be4d9fc0412245bb12f5d1ff2ca3ee420a..831870fca451c585cb1a1dc6b24aad757e2bbaa8 100644 --- a/tensorflow/contrib/kfac/python/ops/op_queue.py +++ b/tensorflow/contrib/kfac/python/ops/op_queue.py @@ -61,7 +61,7 @@ class OpQueue(object): sess: tf.Session. Returns: - Next Op chosen from from 'ops'. + Next Op chosen from 'ops'. """ # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') # returns a str. diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index b34b4e10adb549990b63e9726a88294d03ecb59a..a7473481e44da0b09c047db9af29032918ea6cef 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -250,7 +250,7 @@ def generate_random_signs(shape, dtype=dtypes.float32): return 2 * math_ops.cast(ints, dtype=dtype) - 1 -def fwd_gradients(ys, xs, grad_xs=None): +def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): """Compute forward-mode gradients.""" # See b/37888268. @@ -260,7 +260,8 @@ def fwd_gradients(ys, xs, grad_xs=None): # generated by the first gradients_impl.gradients call. us = [array_ops.zeros_like(y) + float("nan") for y in ys] - dydxs = gradients_impl.gradients(ys, xs, grad_ys=us) + dydxs = gradients_impl.gradients(ys, xs, grad_ys=us, + stop_gradients=stop_gradients) # Deal with strange types that gradients_impl.gradients returns but can't # deal with. diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index a01baea9cc98f74984678e2072463a38997972ed..29ab281b1a603df153619eed2336420ddde9f6a8 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1732,13 +1732,14 @@ class GDN(base.Layer): trainable=True, name=None, **kwargs): - super(GDN, self).__init__(trainable=trainable, name=name, **kwargs) + super(GDN, self).__init__(trainable=trainable, name=name, + activity_regularizer=activity_regularizer, + **kwargs) self.inverse = inverse self._beta_min = beta_min self._gamma_init = gamma_init self._reparam_offset = reparam_offset self.data_format = data_format - self.activity_regularizer = activity_regularizer self._channel_axis() # trigger ValueError early self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5) diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 8813a99f1994ade17cca3b1371a17278e434cef9..1ea25bd1a5685eb6f840e621b5739029a660aa0f 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -176,7 +176,7 @@ class OptimizersTest(test.TestCase): session.run(train, feed_dict={x: 5}) var_value, global_step_value = session.run([var, global_step]) # Due to randomness the following number may change if graph is different. - self.assertAlmostEqual(var_value, 8.5591021, 4) + self.assertAlmostEqual(var_value, 9.86912, 4) self.assertEqual(global_step_value, 1) def testGradientNoiseWithClipping(self): @@ -193,7 +193,7 @@ class OptimizersTest(test.TestCase): variables.global_variables_initializer().run() session.run(train, feed_dict={x: 5}) var_value, global_step_value = session.run([var, global_step]) - self.assertAlmostEqual(var_value, 9.0, 4) + self.assertAlmostEqual(var_value, 9.86912, 4) self.assertEqual(global_step_value, 1) def testGradientClip(self): diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 2fec0508a5603768a301e1e2f9c251a89cf0ef69..12f9bba531a296a00d17956b8ce32e5d7dead380 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -348,6 +348,12 @@ class DNNClassifierTest(test.TestCase): for prediction in predictions: self.assertIn(prediction, (0, 1)) + def _assertClassificationPredictions( + self, expected_len, n_classes, predictions): + self.assertEqual(expected_len, len(predictions)) + for prediction in predictions: + self.assertIn(prediction, range(n_classes)) + def _assertProbabilities(self, expected_batch_size, expected_n_classes, probabilities): self.assertEqual(expected_batch_size, len(probabilities)) @@ -732,7 +738,7 @@ class DNNClassifierTest(test.TestCase): self.assertIn('loss', scores) predicted_classes = classifier.predict_classes( input_fn=_input_fn, as_iterable=False) - self._assertBinaryPredictions(3, predicted_classes) + self._assertClassificationPredictions(3, n_classes, predicted_classes) predictions = classifier.predict(input_fn=_input_fn, as_iterable=False) self.assertAllEqual(predicted_classes, predictions) probabilities = classifier.predict_proba( @@ -765,8 +771,9 @@ class DNNClassifierTest(test.TestCase): feature_column.real_valued_column('age') ] + n_classes = 3 classifier = dnn.DNNClassifier( - n_classes=3, + n_classes=n_classes, feature_columns=feature_columns, hidden_units=[3, 3], config=run_config.RunConfig(tf_random_seed=1)) @@ -780,7 +787,7 @@ class DNNClassifierTest(test.TestCase): predicted_classes = list( classifier.predict_classes( input_fn=predict_input_fn, as_iterable=True)) - self.assertListEqual(predicted_classes, [1, 0, 0]) + self._assertClassificationPredictions(3, n_classes, predicted_classes) predictions = list( classifier.predict( input_fn=predict_input_fn, as_iterable=True)) @@ -788,8 +795,7 @@ class DNNClassifierTest(test.TestCase): predicted_proba = list( classifier.predict_proba( input_fn=predict_input_fn, as_iterable=True)) - self.assertAllClose( - predicted_proba, [[0., 1., 0.], [1., 0., 0.], [1., 0., 0.]], atol=0.3) + self._assertProbabilities(3, n_classes, predicted_proba) def testCustomMetrics(self): """Tests custom evaluation metrics.""" @@ -1214,6 +1220,12 @@ class DNNRegressorTest(test.TestCase): scores = regressor.evaluate(input_fn=_input_fn_eval, steps=1) self.assertIn('loss', scores) + def _assertRegressionOutputs( + self, predictions, expected_shape): + predictions_nparray = np.array(predictions) + self.assertAllEqual(expected_shape, predictions_nparray.shape) + self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.float)) + def testPredict_AsIterableFalse(self): """Tests predict method with as_iterable=False.""" labels = [1., 0., 0.2] @@ -1252,7 +1264,7 @@ class DNNRegressorTest(test.TestCase): self.assertIn('loss', scores) predicted_scores = regressor.predict_scores( input_fn=_input_fn, as_iterable=False) - self.assertAllClose(labels, predicted_scores, atol=0.2) + self._assertRegressionOutputs(predicted_scores, [3]) predictions = regressor.predict(input_fn=_input_fn, as_iterable=False) self.assertAllClose(predicted_scores, predictions) @@ -1296,7 +1308,7 @@ class DNNRegressorTest(test.TestCase): predicted_scores = list( regressor.predict_scores( input_fn=predict_input_fn, as_iterable=True)) - self.assertAllClose(labels, predicted_scores, atol=0.2) + self._assertRegressionOutputs(predicted_scores, [3]) predictions = list( regressor.predict(input_fn=predict_input_fn, as_iterable=True)) self.assertAllClose(predicted_scores, predictions) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 1724d7599d09873f969555cc9382c0753eba463f..69440e823ef1ed2d739f28bc14587891f2de80bb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -639,7 +639,7 @@ class DynamicRnnEstimator(estimator.Estimator): ValueError: `problem_type` is not one of `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`. ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but - `num_classes` is not specifieProblemType + `num_classes` is not specified. ValueError: `prediction_type` is not one of `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`. """ diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 719e5da21df57ed778fe6aee3fe57f3b202dfaa2..468d792a0dccf5cf046a41ed8e1600940a15ac37 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -33,7 +33,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib @@ -635,10 +634,11 @@ def _create_model_fn_ops(features, if (mode != model_fn.ModeKeys.INFER) and (labels is not None): weight_tensor = _weight_tensor(features, weight_column_name) loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor) - # Uses the deprecated API to set the tag explicitly. - # Without it, training and eval losses will show up in different graphs. - logging_ops.scalar_summary( - _summary_key(head_name, mkey.LOSS), weighted_average_loss) + # The name_scope escapism is needed to maintain the same summary tag + # after switching away from the now unsupported API. + with ops.name_scope(""): + summary_loss = array_ops.identity(weighted_average_loss) + summary.scalar(_summary_key(head_name, mkey.LOSS), summary_loss) if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: @@ -1484,8 +1484,12 @@ class _LossOnlyHead(Head): loss = self._loss_fn() if isinstance(loss, list): loss = math_ops.add_n(loss) - logging_ops.scalar_summary( - _summary_key(self.head_name, mkey.LOSS), loss) + # The name_scope escapism is needed to maintain the same summary tag + # after switching away from the now unsupported API. + with ops.name_scope(""): + summary_loss = array_ops.identity(loss) + summary.scalar(_summary_key(self.head_name, mkey.LOSS), + summary_loss) if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError("train_op_fn can not be None in TRAIN mode") @@ -2029,13 +2033,13 @@ def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold): def _streaming_precision_at_threshold(predictions, labels, weights, threshold): precision_tensor, update_op = metrics_lib.precision_at_thresholds( - labels, predictions, (threshold,),_float_weights_or_none(weights)) + labels, predictions, (threshold,), _float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) def _streaming_recall_at_threshold(predictions, labels, weights, threshold): precision_tensor, update_op = metrics_lib.recall_at_thresholds( - labels, predictions, (threshold,),_float_weights_or_none(weights)) + labels, predictions, (threshold,), _float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index b4d9c3fc6fb5906de93950e46a41fb97b24f779f..992b804f59ecd88fedc2fba10d3079f93c4fe83d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of k-means clustering on top of `Estimator` API.""" +"""Implementation of k-means clustering on top of `Estimator` API. + +This module is deprecated. Please use +@{tf.contrib.factorization.KMeansClustering} instead of +@{tf.contrib.learn.KMeansClustering}. It has a similar interface, but uses the +@{tf.estimator.Estimator} API instead of @{tf.contrib.learn.Estimator}. +""" from __future__ import absolute_import from __future__ import division @@ -29,12 +35,17 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.summary import summary from tensorflow.python.ops.control_flow_ops import with_dependencies from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training.session_run_hook import SessionRunArgs +from tensorflow.python.util.deprecation import deprecated + +_USE_TF_CONTRIB_FACTORIZATION = ( + 'Please use tf.contrib.factorization.KMeansClustering instead of' + ' tf.contrib.learn.KMeansClustering. It has a similar interface, but uses' + ' the tf.estimator.Estimator API instead of tf.contrib.learn.Estimator.') class _LossRelativeChangeHook(session_run_hook.SessionRunHook): @@ -106,7 +117,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): """Model function for KMeansClustering estimator.""" assert labels is None, labels (all_scores, model_predictions, losses, - is_initialized, _, init_op, training_op) = clustering_ops.KMeans( + is_initialized, init_op, training_op) = clustering_ops.KMeans( _parse_tensor_or_dict(features), params.get('num_clusters'), initial_clusters=params.get('training_initial_clusters'), @@ -153,6 +164,7 @@ class KMeansClustering(estimator.Estimator): ALL_SCORES = 'all_scores' LOSS_OP_NAME = 'kmeans_loss' + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def __init__(self, num_clusters, model_dir=None, @@ -204,6 +216,7 @@ class KMeansClustering(estimator.Estimator): model_dir=model_dir, config=config) + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def predict_cluster_idx(self, input_fn=None): """Yields predicted cluster indices.""" key = KMeansClustering.CLUSTER_IDX @@ -212,6 +225,7 @@ class KMeansClustering(estimator.Estimator): for result in results: yield result[key] + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def score(self, input_fn=None, steps=None): """Predict total sum of distances to nearest clusters. @@ -229,6 +243,7 @@ class KMeansClustering(estimator.Estimator): self.evaluate( input_fn=input_fn, steps=steps)[KMeansClustering.SCORES]) + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def transform(self, input_fn=None, as_iterable=False): """Transforms each element to distances to cluster centers. @@ -255,6 +270,7 @@ class KMeansClustering(estimator.Estimator): else: return results + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def clusters(self): """Returns cluster centers.""" return super(KMeansClustering, self).get_variable_value(self.CLUSTERS) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 9b55826e627c5198ba7f88505afb866a0f308553..307db76afe20a7743df16d169270a6f319497eb6 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -149,16 +149,16 @@ class Experiment(object): Args: estimator: Object implementing Estimator interface, which could be a - combination of ${tf.contrib.learn.Trainable} and - ${tf.contrib.learn.Evaluable} (deprecated), or - ${tf.estimator.`Estimator}. + combination of @{tf.contrib.learn.Trainable} and + @{tf.contrib.learn.Evaluable} (deprecated), or + @{tf.estimator.Estimator}. train_input_fn: function, returns features and labels for training. eval_input_fn: function, returns features and labels for evaluation. If `eval_steps` is `None`, this should be configured only to produce for a finite number of batches (generally, 1 epoch over the evaluation data). eval_metrics: `dict` of string, metric function. If `None`, default set is used. This should be `None` if the `estimator` is - ${tf.estimator.Estimator}. If metrics are provided they will be + @{tf.estimator.Estimator}. If metrics are provided they will be *appended* to the default set. train_steps: Perform this many steps of training. `None`, the default, means train forever. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py index eaf6ae4ed72148436c3d1aa3838b516c6025b0aa..82848be7df653dd60219317d28f233767746f544 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py @@ -42,16 +42,6 @@ class DataFeederTest(test.TestCase): with self.assertRaisesRegexp(TypeError, 'annot convert'): data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) - def test_input_uint32(self): - data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32) - self._assert_raises(data) - self._assert_raises(self._wrap_dict(data)) - - def test_input_uint64(self): - data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64) - self._assert_raises(data) - self._assert_raises(self._wrap_dict(data)) - def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data): feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) if isinstance(input_data, dict): @@ -87,6 +77,16 @@ class DataFeederTest(test.TestCase): self._assert_dtype(np.int64, dtypes.int64, data) self._assert_dtype(np.int64, dtypes.int64, self._wrap_dict(data)) + def test_input_uint32(self): + data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32) + self._assert_dtype(np.uint32, dtypes.uint32, data) + self._assert_dtype(np.uint32, dtypes.uint32, self._wrap_dict(data)) + + def test_input_uint64(self): + data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64) + self._assert_dtype(np.uint64, dtypes.uint64, data) + self._assert_dtype(np.uint64, dtypes.uint64, self._wrap_dict(data)) + def test_input_uint8(self): data = np.matrix([[1, 2], [3, 4]], dtype=np.uint8) self._assert_dtype(np.uint8, dtypes.uint8, data) diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 9f9740ec492e8b71191aff17f70a007409525ccd..2af723a0d64822e81fa0fbeb106ab812de6ab4e8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -165,7 +165,7 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None, must be None. 2) It accepts two arguments `run_config` and `hparams`, which should be used to create the `Estimator` (`run_config` passed as `config` to its - constructor; `hparams` used as the hyper-paremeters of the model). + constructor; `hparams` used as the hyper-parameters of the model). It must return an `Experiment`. For this case, `output_dir` must be None. output_dir: Base output directory [Deprecated]. schedule: The name of the method in the `Experiment` to run. diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index ee8856ac34219971b886a9f2c4b4e8f2ae639697..49413092a6bae547ddd2cad272b1abb3af1de046 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -50,6 +50,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.training import saver from tensorflow.python.util import compat @@ -108,7 +109,11 @@ def build_standardized_signature_def(input_tensors, output_tensors, classes = _get_classification_classes(output_tensors) scores = _get_classification_scores(output_tensors) if classes is None and scores is None: - (_, classes), = output_tensors.items() + items = list(output_tensors.items()) + if items[0][1].dtype == dtypes.string: + (_, classes), = items + else: + (_, scores), = items return signature_def_utils.classification_signature_def( examples, classes, scores) elif _is_regression_problem(problem_type, input_tensors, output_tensors): @@ -616,7 +621,13 @@ def make_best_model_export_strategy(serving_input_fn, Returns: The string path to the exported directory. """ - + if not checkpoint_path: + # TODO(b/67425018): switch to + # checkpoint_path = estimator.latest_checkpoint() + # as soon as contrib is cleaned up and we can thus be sure that + # estimator is a tf.estimator.Estimator and not a + # tf.contrib.learn.Estimator + checkpoint_path = saver.latest_checkpoint(estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index d4de638338689d2775efe6988af3a058bb128c07..8313aa355d6d40596b40c39f28b64f46c1bb5719 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -76,7 +76,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest # TODO(ebrevdo): Remove once _linear is fully deprecated. -linear = rnn_cell_impl._linear # pylint: disable=protected-access +Linear = rnn_cell_impl._Linear # pylint: disable=protected-access,invalid-name def _extract_argmax_and_embed(embedding, @@ -645,7 +645,7 @@ def attention_decoder(decoder_inputs, query = array_ops.concat(query_list, 1) for a in xrange(num_heads): with variable_scope.variable_scope("Attention_%d" % a): - y = linear(query, attention_vec_size, True) + y = Linear(query, attention_vec_size, True)(query) y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) # Attention mask is a softmax of v^T * tanh(...). s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), @@ -679,7 +679,9 @@ def attention_decoder(decoder_inputs, input_size = inp.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from input: %s" % inp.name) - x = linear([inp] + attns, input_size, True) + + inputs = [inp] + attns + x = Linear(inputs, input_size, True)(inputs) # Run the RNN. cell_output, state = cell(x, state) # Run the attention mechanism. @@ -691,7 +693,8 @@ def attention_decoder(decoder_inputs, attns = attention(state) with variable_scope.variable_scope("AttnOutputProjection"): - output = linear([cell_output] + attns, output_size, True) + inputs = [cell_output] + attns + output = Linear(inputs, output_size, True)(inputs) if loop_function is not None: prev = output outputs.append(output) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 810a3d34eee0a886fcf49ca3209547c9307a6e67..734bac17dc82a61fd4c85b6277625d4a35961958 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -10,152 +10,7 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") - -cuda_py_tests( - name = "linear_operator_test", - size = "small", - srcs = ["python/kernel_tests/linear_operator_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_addition_test", - size = "small", - srcs = ["python/kernel_tests/linear_operator_addition_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_composition_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_composition_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], - tags = ["noasan"], # times out b/63678675 -) - -cuda_py_tests( - name = "linear_operator_diag_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_diag_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - ], -) - -cuda_py_tests( - name = "linear_operator_identity_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_identity_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - ], -) - -cuda_py_tests( - name = "linear_operator_full_matrix_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_full_matrix_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_tril_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_tril_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_udvh_update_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_udvh_update_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], - shard_count = 5, -) - -cuda_py_tests( - name = "linear_operator_util_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_util_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "linalg_py", @@ -176,11 +31,29 @@ py_library( "//tensorflow/python:random_seed", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/ops/linalg", "//third_party/py/numpy", "@six_archive//:six", ], ) +cuda_py_test( + name = "linear_operator_addition_test", + size = "small", + srcs = ["python/kernel_tests/linear_operator_addition_test.py"], + additional_deps = [ + ":linalg_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 44421a6b7de0344a9a4a172ddc0900a44eb74450..4720692c3384ba1bede1f486c1b1e0e69d10a63a 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -21,8 +21,8 @@ See the @{$python/contrib.linalg} guide. @@LinearOperatorIdentity @@LinearOperatorScaledIdentity @@LinearOperatorFullMatrix -@@LinearOperatorTriL -@@LinearOperatorUDVHUpdate +@@LinearOperatorLowerTriangular +@@LinearOperatorLowRankUpdate @@LinearOperatorComposition @@add_operators @@ -33,14 +33,14 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member -from tensorflow.contrib.linalg.python.ops.linear_operator import * from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_composition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_diag import * -from tensorflow.contrib.linalg.python.ops.linear_operator_full_matrix import * -from tensorflow.contrib.linalg.python.ops.linear_operator_identity import * -from tensorflow.contrib.linalg.python.ops.linear_operator_tril import * -from tensorflow.contrib.linalg.python.ops.linear_operator_udvh_update import * +from tensorflow.python.ops.linalg.linear_operator import * +from tensorflow.python.ops.linalg.linear_operator_composition import * +from tensorflow.python.ops.linalg.linear_operator_diag import * +from tensorflow.python.ops.linalg.linear_operator_full_matrix import * +from tensorflow.python.ops.linalg.linear_operator_identity import * +from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * +from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py index 474648475579d3f840d18ba3dd3291b90755a93a..6a72df6dfd8d8c35211bab42b240b83d77160a02 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py @@ -19,10 +19,10 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg as linalg_lib from tensorflow.contrib.linalg.python.ops import linear_operator_addition from tensorflow.python.framework import random_seed from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib from tensorflow.python.platform import test linalg = linalg_lib @@ -114,7 +114,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): def test_diag_tril_diag(self): op1 = linalg.LinearOperatorDiag( [1., 1.], is_non_singular=True, name="diag_a") - op2 = linalg.LinearOperatorTriL( + op2 = linalg.LinearOperatorLowerTriangular( [[2., 0.], [0., 2.]], is_self_adjoint=True, is_non_singular=True, @@ -125,7 +125,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): op_sum = add_operators([op1, op2, op3]) self.assertEqual(1, len(op_sum)) op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorTriL)) + self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular)) self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) # The diag operators will be self-adjoint (because real and diagonal). @@ -140,7 +140,8 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): op0 = linalg.LinearOperatorFullMatrix( [[-1., -1.], [-1., -1.]], name="matrix") op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a") - op2 = linalg.LinearOperatorTriL([[2., 0.], [1.5, 2.]], name="tril") + op2 = linalg.LinearOperatorLowerTriangular( + [[2., 0.], [1.5, 2.]], name="tril") op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b") with self.test_session(): op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator") @@ -189,7 +190,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): def test_tier_1_additions_done_by_tier_1(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], [linear_operator_addition._AddAndReturnTriL()], @@ -199,12 +200,12 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): # _BadAdder) was never reached. op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) def test_tier_1_additions_done_by_tier_1_with_order_flipped(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnTriL()], [linear_operator_addition._AddAndReturnDiag()], @@ -216,12 +217,12 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): # Tier 2 was never used (therefore, _BadAdder didn't raise). op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) def test_cannot_add_everything_so_return_more_than_one_operator(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([2.]) - tril5 = linalg.LinearOperatorTriL([[5.]]) + tril5 = linalg.LinearOperatorLowerTriangular([[5.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], ] @@ -237,7 +238,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): if isinstance(op, linalg.LinearOperatorDiag): found_diag = True self.assertAllClose([[3.]], op.to_dense().eval()) - if isinstance(op, linalg.LinearOperatorTriL): + if isinstance(op, linalg.LinearOperatorLowerTriangular): found_tril = True self.assertAllClose([[5.]], op.to_dense().eval()) self.assertTrue(found_diag and found_tril) @@ -245,7 +246,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): def test_intermediate_tier_is_not_skipped(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], [_BadAdder()], @@ -369,14 +370,14 @@ class AddAndReturnTriLTest(test.TestCase): def test_diag_plus_tril(self): diag = linalg.LinearOperatorDiag([1., 2.]) - tril = linalg.LinearOperatorTriL([[10., 0.], [30., 0.]]) + tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]]) hints = linear_operator_addition._Hints( is_positive_definite=True, is_non_singular=True) self.assertTrue(self._adder.can_add(diag, diag)) self.assertTrue(self._adder.can_add(diag, tril)) operator = self._adder.add(diag, tril, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular)) with self.test_session(): self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval()) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py index 16c4c6e6d67f17d1674b8d1d39f006bc688bc6ce..86130a2c077ce14a7539b281ec809029bc05e071 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py @@ -22,14 +22,14 @@ import abc import six -from tensorflow.contrib.linalg.python.ops import linear_operator -from tensorflow.contrib.linalg.python.ops import linear_operator_diag -from tensorflow.contrib.linalg.python.ops import linear_operator_full_matrix -from tensorflow.contrib.linalg.python.ops import linear_operator_identity -from tensorflow.contrib.linalg.python.ops import linear_operator_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_diag +from tensorflow.python.ops.linalg import linear_operator_full_matrix +from tensorflow.python.ops.linalg import linear_operator_identity +from tensorflow.python.ops.linalg import linear_operator_lower_triangular __all__ = [] @@ -347,7 +347,7 @@ class _AddAndReturnTriL(_Adder): else: op_add_to_tensor, op_other = op2, op1 - return linear_operator_tril.LinearOperatorTriL( + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), is_non_singular=hints.is_non_singular, is_self_adjoint=hints.is_self_adjoint, @@ -397,7 +397,8 @@ def _type(operator): """Returns the type name constant (e.g. _TRIL) for operator.""" if isinstance(operator, linear_operator_diag.LinearOperatorDiag): return _DIAG - if isinstance(operator, linear_operator_tril.LinearOperatorTriL): + if isinstance(operator, + linear_operator_lower_triangular.LinearOperatorLowerTriangular): return _TRIL if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): return _MATRIX diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index f75b0aa1b3e6606b0c92ae94b15b12781fe8b777..33fbbe12d3926606c468d13bef2842b81a857edb 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -15,6 +15,7 @@ py_library( "__init__.py", "python/losses/__init__.py", "python/losses/loss_ops.py", + "python/metric_learning/metric_loss_ops.py", ], srcs_version = "PY2AND3", deps = [ @@ -50,6 +51,49 @@ py_test( ], ) +py_library( + name = "metric_learning_py", + srcs = [ + "python/metric_learning/__init__.py", + "python/metric_learning/metric_loss_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:util", + ], +) + +py_test( + name = "metric_loss_ops_test", + srcs = [ + "python/metric_learning/metric_loss_ops_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":metric_learning_py", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py index 790bf61367d85b79bae4b153328b229b10721b38..db58647d48f0f6f093ef4b71d1e8a7b79e611184 100644 --- a/tensorflow/contrib/losses/__init__.py +++ b/tensorflow/contrib/losses/__init__.py @@ -22,10 +22,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.losses.python import metric_learning # pylint: disable=wildcard-import from tensorflow.contrib.losses.python.losses import * # pylint: enable=wildcard-import - from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ @@ -43,5 +43,6 @@ _allowed_symbols = [ 'sigmoid_cross_entropy', 'softmax_cross_entropy', 'sparse_softmax_cross_entropy', + 'metric_learning' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 1d2477b8b794240bd348cec7f626be794181ffb4..7c523ad49265aaf32c8d5a8ae04d3e93262a1b55 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.deprecation import deprecated_args __all__ = ["absolute_difference", "add_loss", @@ -623,8 +624,9 @@ def mean_pairwise_squared_error( @deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.") +@deprecated_args(None, "dim is deprecated, use axis instead", "dim") def cosine_distance( - predictions, labels=None, dim=None, weights=1.0, scope=None): + predictions, labels=None, axis=None, weights=1.0, scope=None, dim=None): """Adds a cosine-distance loss to the training procedure. Note that the function assumes that `predictions` and `labels` are already @@ -633,10 +635,11 @@ def cosine_distance( Args: predictions: An arbitrary matrix. labels: A `Tensor` whose shape matches 'predictions' - dim: The dimension along which the cosine distance is computed. + axis: The dimension along which the cosine distance is computed. weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. scope: The scope for the operations performed in computing the loss. + dim: The old (deprecated) name for `axis`. Returns: A scalar `Tensor` representing the loss value. @@ -645,8 +648,12 @@ def cosine_distance( ValueError: If `predictions` shape doesn't match `labels` shape, or `weights` is `None`. """ - if dim is None: - raise ValueError("`dim` cannot be None.") + if dim is not None: + if axis is not None: + raise ValueError("Cannot specify both 'axis' and 'dim'") + axis = dim + if axis is None and dim is None: + raise ValueError("You must specify 'axis'.") with ops.name_scope(scope, "cosine_distance_loss", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -655,5 +662,5 @@ def cosine_distance( labels = math_ops.to_float(labels) radial_diffs = math_ops.multiply(predictions, labels) - losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,]) + losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[axis,]) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/losses/python/metric_learning/__init__.py b/tensorflow/contrib/losses/python/metric_learning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e551d6acafb5c565965503075e8416e01c20a71 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for building neural network losses. + +See @{$python/contrib.losses}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.losses.python.metric_learning.metric_loss_ops import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'contrastive_loss', + 'cluster_loss', + 'lifted_struct_loss', + 'npairs_loss', + 'npairs_loss_multilabel', + 'triplet_semihard_loss', +] +remove_undocumented(__name__, _allowed_symbols) + + diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a57ba51bcf0a292490dfaa9e556f6e5811ed66 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -0,0 +1,1031 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements various metric learning losses.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.summary import summary +try: + # pylint: disable=g-import-not-at-top + from sklearn import metrics + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + + +def pairwise_distance(feature, squared=False): + """Computes the pairwise distance matrix with numerical stability. + + output[i, j] = || feature[i, :] - feature[j, :] ||_2 + + Args: + feature: 2-D Tensor of size [number of data, feature dimension]. + squared: Boolean, whether or not to square the pairwise distances. + + Returns: + pairwise_distances: 2-D Tensor of size [number of data, number of data]. + """ + pairwise_distances_squared = math_ops.add( + math_ops.reduce_sum( + math_ops.square(feature), + axis=[1], + keep_dims=True), + math_ops.reduce_sum( + math_ops.square( + array_ops.transpose(feature)), + axis=[0], + keep_dims=True)) - 2.0 * math_ops.matmul( + feature, array_ops.transpose(feature)) + + # Deal with numerical inaccuracies. Set small negatives to zero. + pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0) + # Get the mask where the zero distances are at. + error_mask = math_ops.less_equal(pairwise_distances_squared, 0.0) + + # Optionally take the sqrt. + if squared: + pairwise_distances = pairwise_distances_squared + else: + pairwise_distances = math_ops.sqrt( + pairwise_distances_squared + math_ops.to_float(error_mask) * 1e-16) + + # Undo conditionally adding 1e-16. + pairwise_distances = math_ops.multiply( + pairwise_distances, math_ops.to_float(math_ops.logical_not(error_mask))) + + num_data = array_ops.shape(feature)[0] + # Explicitly set diagonals to zero. + mask_offdiagonals = array_ops.ones_like(pairwise_distances) - array_ops.diag( + array_ops.ones([num_data])) + pairwise_distances = math_ops.multiply(pairwise_distances, mask_offdiagonals) + return pairwise_distances + + +def contrastive_loss(labels, embeddings_anchor, embeddings_positive, + margin=1.0): + """Computes the contrastive loss. + + This loss encourages the embedding to be close to each other for + the samples of the same label and the embedding to be far apart at least + by the margin constant for the samples of different labels. + See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + binary labels indicating positive vs negative pair. + embeddings_anchor: 2-D float `Tensor` of embedding vectors for the anchor + images. Embeddings should be l2 normalized. + embeddings_positive: 2-D float `Tensor` of embedding vectors for the + positive images. Embeddings should be l2 normalized. + margin: margin term in the loss definition. + + Returns: + contrastive_loss: tf.float32 scalar. + """ + # Get per pair distances + distances = math_ops.sqrt( + math_ops.reduce_sum( + math_ops.square(embeddings_anchor - embeddings_positive), 1)) + + # Add contrastive loss for the siamese network. + # label here is {0,1} for neg, pos. + return math_ops.reduce_mean( + math_ops.to_float(labels) * math_ops.square(distances) + + (1. - math_ops.to_float(labels)) * + math_ops.square(math_ops.maximum(margin - distances, 0.)), + name='contrastive_loss') + + +def masked_maximum(data, mask, dim=1): + """Computes the axis wise maximum over chosen elements. + + Args: + data: 2-D float `Tensor` of size [n, m]. + mask: 2-D Boolean `Tensor` of size [n, m]. + dim: The dimension over which to compute the maximum. + + Returns: + masked_maximums: N-D `Tensor`. + The maximized dimension is of size 1 after the operation. + """ + axis_minimums = math_ops.reduce_min(data, dim, keep_dims=True) + masked_maximums = math_ops.reduce_max( + math_ops.multiply( + data - axis_minimums, mask), dim, keep_dims=True) + axis_minimums + return masked_maximums + + +def masked_minimum(data, mask, dim=1): + """Computes the axis wise minimum over chosen elements. + + Args: + data: 2-D float `Tensor` of size [n, m]. + mask: 2-D Boolean `Tensor` of size [n, m]. + dim: The dimension over which to compute the minimum. + + Returns: + masked_minimums: N-D `Tensor`. + The minimized dimension is of size 1 after the operation. + """ + axis_maximums = math_ops.reduce_max(data, dim, keep_dims=True) + masked_minimums = math_ops.reduce_min( + math_ops.multiply( + data - axis_maximums, mask), dim, keep_dims=True) + axis_maximums + return masked_minimums + + +def triplet_semihard_loss(labels, embeddings, margin=1.0): + """Computes the triplet loss with semi-hard negative mining. + + The loss encourages the positive distances (between a pair of embeddings with + the same labels) to be smaller than the minimum negative distance among + which are at least greater than the positive distance plus the margin constant + (called semi-hard negative) in the mini-batch. If no such negative exists, + uses the largest negative distance instead. + See: https://arxiv.org/abs/1503.03832. + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + multiclass integer labels. + embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should + be l2 normalized. + margin: Float, margin term in the loss definition. + + Returns: + triplet_loss: tf.float32 scalar. + """ + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + # Build pairwise squared distance matrix. + pdist_matrix = pairwise_distance(embeddings, squared=True) + # Build pairwise binary adjacency matrix. + adjacency = math_ops.equal(labels, array_ops.transpose(labels)) + # Invert so we can select negatives only. + adjacency_not = math_ops.logical_not(adjacency) + + batch_size = array_ops.size(labels) + + # Compute the mask. + pdist_matrix_tile = array_ops.tile(pdist_matrix, [batch_size, 1]) + mask = math_ops.logical_and( + array_ops.tile(adjacency_not, [batch_size, 1]), + math_ops.greater( + pdist_matrix_tile, array_ops.reshape( + array_ops.transpose(pdist_matrix), [-1, 1]))) + mask_final = array_ops.reshape( + math_ops.greater( + math_ops.reduce_sum( + math_ops.cast( + mask, dtype=dtypes.float32), 1, keep_dims=True), + 0.0), [batch_size, batch_size]) + mask_final = array_ops.transpose(mask_final) + + adjacency_not = math_ops.cast(adjacency_not, dtype=dtypes.float32) + mask = math_ops.cast(mask, dtype=dtypes.float32) + + # negatives_outside: smallest D_an where D_an > D_ap. + negatives_outside = array_ops.reshape( + masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size]) + negatives_outside = array_ops.transpose(negatives_outside) + + # negatives_inside: largest D_an. + negatives_inside = array_ops.tile( + masked_maximum(pdist_matrix, adjacency_not), [1, batch_size]) + semi_hard_negatives = array_ops.where( + mask_final, negatives_outside, negatives_inside) + + loss_mat = math_ops.add(margin, pdist_matrix - semi_hard_negatives) + + mask_positives = math_ops.cast( + adjacency, dtype=dtypes.float32) - array_ops.diag( + array_ops.ones([batch_size])) + + # In lifted-struct, the authors multiply 0.5 for upper triangular + # in semihard, they take all positive pairs except the diagonal. + num_positives = math_ops.reduce_sum(mask_positives) + + triplet_loss = math_ops.truediv( + math_ops.reduce_sum( + math_ops.maximum( + math_ops.multiply(loss_mat, mask_positives), 0.0)), + num_positives, + name='triplet_semihard_loss') + + return triplet_loss + + +# pylint: disable=line-too-long +def npairs_loss(labels, embeddings_anchor, embeddings_positive, + reg_lambda=0.002, print_losses=False): + """Computes the npairs loss. + + Npairs loss expects paired data where a pair is composed of samples from the + same labels and each pairs in the minibatch have different labels. The loss + has two components. The first component is the L2 regularizer on the + embedding vectors. The second component is the sum of cross entropy loss + which takes each row of the pair-wise similarity matrix as logits and + the remapped one-hot labels as labels. + + See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf + + Args: + labels: 1-D tf.int32 `Tensor` of shape [batch_size/2]. + embeddings_anchor: 2-D Tensor of shape [batch_size/2, embedding_dim] for the + embedding vectors for the anchor images. Embeddings should not be + l2 normalized. + embeddings_positive: 2-D Tensor of shape [batch_size/2, embedding_dim] for the + embedding vectors for the positive images. Embeddings should not be + l2 normalized. + reg_lambda: Float. L2 regularization term on the embedding vectors. + print_losses: Boolean. Option to print the xent and l2loss. + + Returns: + npairs_loss: tf.float32 scalar. + """ + # pylint: enable=line-too-long + # Add the regularizer on the embedding. + reg_anchor = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_anchor), 1)) + reg_positive = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_positive), 1)) + l2loss = math_ops.multiply( + 0.25 * reg_lambda, reg_anchor + reg_positive, name='l2loss') + + # Get per pair similarities. + similarity_matrix = math_ops.matmul( + embeddings_anchor, embeddings_positive, transpose_a=False, + transpose_b=True) + + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + labels_remapped = math_ops.to_float( + math_ops.equal(labels, array_ops.transpose(labels))) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + + # Add the softmax loss. + xent_loss = nn.softmax_cross_entropy_with_logits( + logits=similarity_matrix, labels=labels_remapped) + xent_loss = math_ops.reduce_mean(xent_loss, name='xentropy') + + if print_losses: + xent_loss = logging_ops.Print( + xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss]) + + return l2loss + xent_loss + + +def _build_multilabel_adjacency(sparse_labels): + """Builds multilabel adjacency matrix. + + As of March 14th, 2017, there's no op for the dot product between + two sparse tensors in TF. However, there is `sparse_minimum` op which is + equivalent to an AND op between two sparse boolean tensors. + This computes the dot product between two sparse boolean inputs. + + Args: + sparse_labels: List of 1-D boolean sparse tensors. + + Returns: + adjacency_matrix: 2-D dense `Tensor`. + """ + num_pairs = len(sparse_labels) + adjacency_matrix = array_ops.zeros([num_pairs, num_pairs]) + for i in range(num_pairs): + for j in range(num_pairs): + sparse_dot_product = math_ops.to_float( + sparse_ops.sparse_reduce_sum(sparse_ops.sparse_minimum( + sparse_labels[i], sparse_labels[j]))) + sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 0) + sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 1) + one_hot_matrix = array_ops.pad(sparse_dot_product, + [[i, num_pairs-i-1], + [j, num_pairs-j-1]], 'CONSTANT') + adjacency_matrix += one_hot_matrix + + return adjacency_matrix + + +def npairs_loss_multilabel(sparse_labels, embeddings_anchor, + embeddings_positive, reg_lambda=0.002, + print_losses=False): + r"""Computes the npairs loss with multilabel data. + + Npairs loss expects paired data where a pair is composed of samples from the + same labels and each pairs in the minibatch have different labels. The loss + has two components. The first component is the L2 regularizer on the + embedding vectors. The second component is the sum of cross entropy loss + which takes each row of the pair-wise similarity matrix as logits and + the remapped one-hot labels as labels. Here, the similarity is defined by the + dot product between two embedding vectors. S_{i,j} = f(x_i)^T f(x_j) + + To deal with multilabel inputs, we use the count of label intersection + i.e. L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) | + Then we normalize each rows of the count based label matrix so that each row + sums to one. + + Args: + sparse_labels: List of 1-D Boolean `SparseTensor` of dense_shape + [batch_size/2, num_classes] labels for the anchor-pos pairs. + embeddings_anchor: 2-D `Tensor` of shape [batch_size/2, embedding_dim] for + the embedding vectors for the anchor images. Embeddings should not be + l2 normalized. + embeddings_positive: 2-D `Tensor` of shape [batch_size/2, embedding_dim] for + the embedding vectors for the positive images. Embeddings should not be + l2 normalized. + reg_lambda: Float. L2 regularization term on the embedding vectors. + print_losses: Boolean. Option to print the xent and l2loss. + + Returns: + npairs_loss: tf.float32 scalar. + Raises: + TypeError: When the specified sparse_labels is not a `SparseTensor`. + """ + if False in [isinstance( + l, sparse_tensor.SparseTensor) for l in sparse_labels]: + raise TypeError( + 'sparse_labels must be a list of SparseTensors, but got %s' % str( + sparse_labels)) + + with ops.name_scope('NpairsLossMultiLabel'): + # Add the regularizer on the embedding. + reg_anchor = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_anchor), 1)) + reg_positive = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_positive), 1)) + l2loss = math_ops.multiply(0.25 * reg_lambda, + reg_anchor + reg_positive, name='l2loss') + + # Get per pair similarities. + similarity_matrix = math_ops.matmul( + embeddings_anchor, embeddings_positive, transpose_a=False, + transpose_b=True) + + # TODO(coreylynch): need to check the sparse values + # TODO(coreylynch): are composed only of 0's and 1's. + + multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels) + labels_remapped = math_ops.to_float(multilabel_adjacency_matrix) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + + # Add the softmax loss. + xent_loss = nn.softmax_cross_entropy_with_logits( + logits=similarity_matrix, labels=labels_remapped) + xent_loss = math_ops.reduce_mean(xent_loss, name='xentropy') + + if print_losses: + xent_loss = logging_ops.Print( + xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss]) + + return l2loss + xent_loss + + +def lifted_struct_loss(labels, embeddings, margin=1.0): + """Computes the lifted structured loss. + + The loss encourages the positive distances (between a pair of embeddings + with the same labels) to be smaller than any negative distances (between a + pair of embeddings with different labels) in the mini-batch in a way + that is differentiable with respect to the embedding vectors. + See: https://arxiv.org/abs/1511.06452. + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + multiclass integer labels. + embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should not + be l2 normalized. + margin: Float, margin term in the loss definition. + + Returns: + lifted_loss: tf.float32 scalar. + """ + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + # Build pairwise squared distance matrix. + pairwise_distances = pairwise_distance(embeddings) + + # Build pairwise binary adjacency matrix. + adjacency = math_ops.equal(labels, array_ops.transpose(labels)) + # Invert so we can select negatives only. + adjacency_not = math_ops.logical_not(adjacency) + + batch_size = array_ops.size(labels) + + diff = margin - pairwise_distances + mask = math_ops.cast(adjacency_not, dtype=dtypes.float32) + # Safe maximum: Temporarily shift negative distances + # above zero before taking max. + # this is to take the max only among negatives. + row_minimums = math_ops.reduce_min(diff, 1, keep_dims=True) + row_negative_maximums = math_ops.reduce_max( + math_ops.multiply( + diff - row_minimums, mask), 1, keep_dims=True) + row_minimums + + # Compute the loss. + # Keep track of matrix of maximums where M_ij = max(m_i, m_j) + # where m_i is the max of alpha - negative D_i's. + # This matches the Caffe loss layer implementation at: + # https://github.com/rksltnl/Caffe-Deep-Metric-Learning-CVPR16/blob/0efd7544a9846f58df923c8b992198ba5c355454/src/caffe/layers/lifted_struct_similarity_softmax_layer.cpp # pylint: disable=line-too-long + + max_elements = math_ops.maximum( + row_negative_maximums, array_ops.transpose(row_negative_maximums)) + diff_tiled = array_ops.tile(diff, [batch_size, 1]) + mask_tiled = array_ops.tile(mask, [batch_size, 1]) + max_elements_vect = array_ops.reshape( + array_ops.transpose(max_elements), [-1, 1]) + + loss_exp_left = array_ops.reshape( + math_ops.reduce_sum(math_ops.multiply( + math_ops.exp( + diff_tiled - max_elements_vect), + mask_tiled), 1, keep_dims=True), [batch_size, batch_size]) + + loss_mat = max_elements + math_ops.log( + loss_exp_left + array_ops.transpose(loss_exp_left)) + # Add the positive distance. + loss_mat += pairwise_distances + + mask_positives = math_ops.cast( + adjacency, dtype=dtypes.float32) - array_ops.diag( + array_ops.ones([batch_size])) + + # *0.5 for upper triangular, and another *0.5 for 1/2 factor for loss^2. + num_positives = math_ops.reduce_sum(mask_positives) / 2.0 + + lifted_loss = math_ops.truediv( + 0.25 * math_ops.reduce_sum( + math_ops.square( + math_ops.maximum( + math_ops.multiply(loss_mat, mask_positives), 0.0))), + num_positives, + name='liftedstruct_loss') + return lifted_loss + + +def update_1d_tensor(y, index, value): + """Updates 1d tensor y so that y[index] = value. + + Args: + y: 1-D Tensor. + index: index of y to modify. + value: new value to write at y[index]. + + Returns: + y_mod: 1-D Tensor. Tensor y after the update. + """ + value = array_ops.squeeze(value) + # modify the 1D tensor x at index with value. + # ex) chosen_ids = update_1D_tensor(chosen_ids, cluster_idx, best_medoid) + y_before = array_ops.slice(y, [0], [index]) + y_after = array_ops.slice(y, [index + 1], [-1]) + y_mod = array_ops.concat([y_before, [value], y_after], 0) + return y_mod + + +def get_cluster_assignment(pairwise_distances, centroid_ids): + """Assign data points to the neareset centroids. + + Tensorflow has numerical instability and doesn't always choose + the data point with theoretically zero distance as it's nearest neighbor. + Thus, for each centroid in centroid_ids, explicitly assign + the centroid itself as the nearest centroid. + This is done through the mask tensor and the constraint_vect tensor. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + centroid_ids: 1-D Tensor of centroid indices. + + Returns: + y_fixed: 1-D tensor of cluster assignment. + """ + predictions = math_ops.argmin( + array_ops.gather(pairwise_distances, centroid_ids), dimension=0) + batch_size = array_ops.shape(pairwise_distances)[0] + + # Deal with numerical instability + mask = math_ops.reduce_any(array_ops.one_hot( + centroid_ids, batch_size, True, False, axis=-1, dtype=dtypes.bool), + axis=0) + constraint_one_hot = math_ops.multiply( + array_ops.one_hot(centroid_ids, + batch_size, + array_ops.constant(1, dtype=dtypes.int64), + array_ops.constant(0, dtype=dtypes.int64), + axis=0, + dtype=dtypes.int64), + math_ops.to_int64(math_ops.range(array_ops.shape(centroid_ids)[0]))) + constraint_vect = math_ops.reduce_sum( + array_ops.transpose(constraint_one_hot), axis=0) + + y_fixed = array_ops.where(mask, constraint_vect, predictions) + return y_fixed + + +def compute_facility_energy(pairwise_distances, centroid_ids): + """Compute the average travel distance to the assigned centroid. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + centroid_ids: 1-D Tensor of indices. + + Returns: + facility_energy: dtypes.float32 scalar. + """ + return -1.0 * math_ops.reduce_sum( + math_ops.reduce_min( + array_ops.gather(pairwise_distances, centroid_ids), axis=0)) + + +def compute_clustering_score(labels, predictions, margin_type): + """Computes the clustering score via sklearn.metrics functions. + + There are various ways to compute the clustering score. Intuitively, + we want to measure the agreement of two clustering assignments (labels vs + predictions) ignoring the permutations and output a score from zero to one. + (where the values close to one indicate significant agreement). + This code supports following scoring functions: + nmi: normalized mutual information + ami: adjusted mutual information + ari: adjusted random index + vmeasure: v-measure + const: indicator checking whether the two clusterings are the same. + See http://scikit-learn.org/stable/modules/classes.html#clustering-metrics + for the detailed descriptions. + Args: + labels: 1-D Tensor. ground truth cluster assignment. + predictions: 1-D Tensor. predicted cluster assignment. + margin_type: Type of structured margin to use. Default is nmi. + Returns: + clustering_score: dtypes.float32 scalar. + The possible valid values are from zero to one. + Zero means the worst clustering and one means the perfect clustering. + Raises: + ValueError: margin_type is not recognized. + """ + margin_type_to_func = { + 'nmi': _compute_nmi_score, + 'ami': _compute_ami_score, + 'ari': _compute_ari_score, + 'vmeasure': _compute_vmeasure_score, + 'const': _compute_zeroone_score + } + + if margin_type not in margin_type_to_func: + raise ValueError('Unrecognized margin_type: %s' % margin_type) + clustering_score_fn = margin_type_to_func[margin_type] + return array_ops.squeeze(clustering_score_fn(labels, predictions)) + + +def _compute_nmi_score(labels, predictions): + return math_ops.to_float( + script_ops.py_func( + metrics.normalized_mutual_info_score, [labels, predictions], + [dtypes.float64], + name='nmi')) + + +def _compute_ami_score(labels, predictions): + ami_score = math_ops.to_float( + script_ops.py_func( + metrics.adjusted_mutual_info_score, [labels, predictions], + [dtypes.float64], + name='ami')) + return math_ops.maximum(0.0, ami_score) + + +def _compute_ari_score(labels, predictions): + ari_score = math_ops.to_float( + script_ops.py_func( + metrics.adjusted_rand_score, [labels, predictions], [dtypes.float64], + name='ari')) + # ari score can go below 0 + # http://scikit-learn.org/stable/modules/clustering.html#adjusted-rand-score + return math_ops.maximum(0.0, ari_score) + + +def _compute_vmeasure_score(labels, predictions): + vmeasure_score = math_ops.to_float( + script_ops.py_func( + metrics.v_measure_score, [labels, predictions], [dtypes.float64], + name='vmeasure')) + return math_ops.maximum(0.0, vmeasure_score) + + +def _compute_zeroone_score(labels, predictions): + zeroone_score = math_ops.to_float( + math_ops.equal( + math_ops.reduce_sum( + math_ops.to_int32(math_ops.equal(labels, predictions))), + array_ops.shape(labels)[0])) + return zeroone_score + + +def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids, + candidate_ids, margin_multiplier, + margin_type): + """Find the next centroid that maximizes the loss augmented inference. + + This function is a subroutine called from compute_augmented_facility_locations + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of current centroid indices. + candidate_ids: 1-D Tensor of candidate indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + integer index. + """ + num_candidates = array_ops.shape(candidate_ids)[0] + + pairwise_distances_chosen = array_ops.gather(pairwise_distances, chosen_ids) + pairwise_distances_candidate = array_ops.gather( + pairwise_distances, candidate_ids) + pairwise_distances_chosen_tile = array_ops.tile( + pairwise_distances_chosen, [1, num_candidates]) + + candidate_scores = -1.0 * math_ops.reduce_sum( + array_ops.reshape( + math_ops.reduce_min( + array_ops.concat([ + pairwise_distances_chosen_tile, + array_ops.reshape(pairwise_distances_candidate, [1, -1]) + ], 0), + axis=0, + keep_dims=True), [num_candidates, -1]), + axis=1) + + nmi_scores = array_ops.zeros([num_candidates]) + iteration = array_ops.constant(0) + + def func_cond(iteration, nmi_scores): + del nmi_scores # Unused in func_cond() + return iteration < num_candidates + + def func_body(iteration, nmi_scores): + predictions = get_cluster_assignment( + pairwise_distances, + array_ops.concat([chosen_ids, [candidate_ids[iteration]]], 0)) + nmi_score_i = compute_clustering_score(labels, predictions, margin_type) + pad_before = array_ops.zeros([iteration]) + pad_after = array_ops.zeros([num_candidates - 1 - iteration]) + # return 1 - NMI score as the structured loss. + # because NMI is higher the better [0,1]. + return iteration + 1, nmi_scores + array_ops.concat( + [pad_before, [1.0 - nmi_score_i], pad_after], 0) + + _, nmi_scores = control_flow_ops.while_loop( + func_cond, func_body, [iteration, nmi_scores]) + + candidate_scores = math_ops.add( + candidate_scores, margin_multiplier * nmi_scores) + + argmax_index = math_ops.to_int32( + math_ops.argmax(candidate_scores, dimension=0)) + + return candidate_ids[argmax_index] + + +def compute_augmented_facility_locations(pairwise_distances, labels, all_ids, + margin_multiplier, margin_type): + """Computes the centroid locations. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + all_ids: 1-D Tensor of all data indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: 1-D Tensor of chosen centroid indices. + """ + + def func_cond_augmented(iteration, chosen_ids): + del chosen_ids # Unused argument in func_cond_augmented. + return iteration < num_classes + + def func_body_augmented(iteration, chosen_ids): + # find a new facility location to add + # based on the clustering score and the NMI score + candidate_ids = array_ops.setdiff1d(all_ids, chosen_ids)[0] + new_chosen_idx = _find_loss_augmented_facility_idx(pairwise_distances, + labels, chosen_ids, + candidate_ids, + margin_multiplier, + margin_type) + chosen_ids = array_ops.concat([chosen_ids, [new_chosen_idx]], 0) + return iteration + 1, chosen_ids + + num_classes = array_ops.size(array_ops.unique(labels)[0]) + chosen_ids = array_ops.constant(0, dtype=dtypes.int32, shape=[0]) + + # num_classes get determined at run time based on the sampled batch. + iteration = array_ops.constant(0) + + _, chosen_ids = control_flow_ops.while_loop( + func_cond_augmented, + func_body_augmented, [iteration, chosen_ids], + shape_invariants=[iteration.get_shape(), tensor_shape.TensorShape( + [None])]) + return chosen_ids + + +def update_medoid_per_cluster(pairwise_distances, pairwise_distances_subset, + labels, chosen_ids, cluster_member_ids, + cluster_idx, margin_multiplier, margin_type): + """Updates the cluster medoid per cluster. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + pairwise_distances_subset: 2-D Tensor of pairwise distances for one cluster. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of cluster centroid indices. + cluster_member_ids: 1-D Tensor of cluster member indices for one cluster. + cluster_idx: Index of this one cluster. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + + def func_cond(iteration, scores_margin): + del scores_margin # Unused variable scores_margin. + return iteration < num_candidates + + def func_body(iteration, scores_margin): + # swap the current medoid with the candidate cluster member + candidate_medoid = math_ops.to_int32(cluster_member_ids[iteration]) + tmp_chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, candidate_medoid) + predictions = get_cluster_assignment(pairwise_distances, tmp_chosen_ids) + metric_score = compute_clustering_score(labels, predictions, margin_type) + pad_before = array_ops.zeros([iteration]) + pad_after = array_ops.zeros([num_candidates - 1 - iteration]) + return iteration + 1, scores_margin + array_ops.concat( + [pad_before, [1.0 - metric_score], pad_after], 0) + + # pairwise_distances_subset is of size [p, 1, 1, p], + # the intermediate dummy dimensions at + # [1, 2] makes this code work in the edge case where p=1. + # this happens if the cluster size is one. + scores_fac = -1.0 * math_ops.reduce_sum( + array_ops.squeeze(pairwise_distances_subset, [1, 2]), axis=0) + + iteration = array_ops.constant(0) + num_candidates = array_ops.size(cluster_member_ids) + scores_margin = array_ops.zeros([num_candidates]) + + _, scores_margin = control_flow_ops.while_loop(func_cond, func_body, + [iteration, scores_margin]) + candidate_scores = math_ops.add(scores_fac, margin_multiplier * scores_margin) + + argmax_index = math_ops.to_int32( + math_ops.argmax(candidate_scores, dimension=0)) + + best_medoid = math_ops.to_int32(cluster_member_ids[argmax_index]) + chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, best_medoid) + return chosen_ids + + +def update_all_medoids(pairwise_distances, predictions, labels, chosen_ids, + margin_multiplier, margin_type): + """Updates all cluster medoids a cluster at a time. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + predictions: 1-D Tensor of predicted cluster assignment. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of cluster centroid indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + + def func_cond_augmented_pam(iteration, chosen_ids): + del chosen_ids # Unused argument. + return iteration < num_classes + + def func_body_augmented_pam(iteration, chosen_ids): + """Call the update_medoid_per_cluster subroutine.""" + mask = math_ops.equal( + math_ops.to_int64(predictions), math_ops.to_int64(iteration)) + this_cluster_ids = array_ops.where(mask) + + pairwise_distances_subset = array_ops.transpose( + array_ops.gather( + array_ops.transpose( + array_ops.gather(pairwise_distances, this_cluster_ids)), + this_cluster_ids)) + + chosen_ids = update_medoid_per_cluster(pairwise_distances, + pairwise_distances_subset, labels, + chosen_ids, this_cluster_ids, + iteration, margin_multiplier, + margin_type) + return iteration + 1, chosen_ids + + unique_class_ids = array_ops.unique(labels)[0] + num_classes = array_ops.size(unique_class_ids) + iteration = array_ops.constant(0) + + _, chosen_ids = control_flow_ops.while_loop( + func_cond_augmented_pam, func_body_augmented_pam, [iteration, chosen_ids]) + return chosen_ids + + +def compute_augmented_facility_locations_pam(pairwise_distances, + labels, + margin_multiplier, + margin_type, + chosen_ids, + pam_max_iter=5): + """Refine the cluster centroids with PAM local search. + + For fixed iterations, alternate between updating the cluster assignment + and updating cluster medoids. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + chosen_ids: 1-D Tensor of initial estimate of cluster centroids. + pam_max_iter: Number of refinement iterations. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + for _ in range(pam_max_iter): + # update the cluster assignment given the chosen_ids (S_pred) + predictions = get_cluster_assignment(pairwise_distances, chosen_ids) + + # update the medoids per each cluster + chosen_ids = update_all_medoids(pairwise_distances, predictions, labels, + chosen_ids, margin_multiplier, margin_type) + + return chosen_ids + + +def compute_gt_cluster_score(pairwise_distances, labels): + """Compute ground truth facility location score. + + Loop over each unique classes and compute average travel distances. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + + Returns: + gt_cluster_score: dtypes.float32 score. + """ + unique_class_ids = array_ops.unique(labels)[0] + num_classes = array_ops.size(unique_class_ids) + iteration = array_ops.constant(0) + gt_cluster_score = array_ops.constant(0.0, dtype=dtypes.float32) + + def func_cond(iteration, gt_cluster_score): + del gt_cluster_score # Unused argument. + return iteration < num_classes + + def func_body(iteration, gt_cluster_score): + """Per each cluster, compute the average travel distance.""" + mask = math_ops.equal(labels, unique_class_ids[iteration]) + this_cluster_ids = array_ops.where(mask) + pairwise_distances_subset = array_ops.transpose( + array_ops.gather( + array_ops.transpose( + array_ops.gather(pairwise_distances, this_cluster_ids)), + this_cluster_ids)) + this_cluster_score = -1.0 * math_ops.reduce_min( + math_ops.reduce_sum( + pairwise_distances_subset, axis=0)) + return iteration + 1, gt_cluster_score + this_cluster_score + + _, gt_cluster_score = control_flow_ops.while_loop( + func_cond, func_body, [iteration, gt_cluster_score]) + return gt_cluster_score + + +def cluster_loss(labels, + embeddings, + margin_multiplier, + enable_pam_finetuning=True, + margin_type='nmi', + print_losses=False): + """Computes the clustering loss. + + The following structured margins are supported: + nmi: normalized mutual information + ami: adjusted mutual information + ari: adjusted random index + vmeasure: v-measure + const: indicator checking whether the two clusterings are the same. + + Args: + labels: 2-D Tensor of labels of shape [batch size, 1] + embeddings: 2-D Tensor of embeddings of shape + [batch size, embedding dimension]. Embeddings should be l2 normalized. + margin_multiplier: float32 scalar. multiplier on the structured margin term + See section 3.2 of paper for discussion. + enable_pam_finetuning: Boolean, Whether to run local pam refinement. + See section 3.4 of paper for discussion. + margin_type: Type of structured margin to use. See section 3.2 of + paper for discussion. Can be 'nmi', 'ami', 'ari', 'vmeasure', 'const'. + print_losses: Boolean. Option to print the loss. + + Paper: https://arxiv.org/abs/1612.01213. + + Returns: + clustering_loss: A float32 scalar `Tensor`. + Raises: + ImportError: If sklearn dependency is not installed. + """ + if not HAS_SKLEARN: + raise ImportError('Cluster loss depends on sklearn.') + pairwise_distances = pairwise_distance(embeddings) + labels = array_ops.squeeze(labels) + all_ids = math_ops.range(array_ops.shape(embeddings)[0]) + + # Compute the loss augmented inference and get the cluster centroids. + chosen_ids = compute_augmented_facility_locations(pairwise_distances, labels, + all_ids, margin_multiplier, + margin_type) + # Given the predicted centroids, compute the clustering score. + score_pred = compute_facility_energy(pairwise_distances, chosen_ids) + + # Branch whether to use PAM finetuning. + if enable_pam_finetuning: + # Initialize with augmented facility solution. + chosen_ids = compute_augmented_facility_locations_pam(pairwise_distances, + labels, + margin_multiplier, + margin_type, + chosen_ids) + score_pred = compute_facility_energy(pairwise_distances, chosen_ids) + + # Given the predicted centroids, compute the cluster assignments. + predictions = get_cluster_assignment(pairwise_distances, chosen_ids) + + # Compute the clustering (i.e. NMI) score between the two assignments. + clustering_score_pred = compute_clustering_score(labels, predictions, + margin_type) + + # Compute the clustering score from labels. + score_gt = compute_gt_cluster_score(pairwise_distances, labels) + + # Compute the hinge loss. + clustering_loss = math_ops.maximum( + score_pred + margin_multiplier * (1.0 - clustering_score_pred) - score_gt, + 0.0, + name='clustering_loss') + clustering_loss.set_shape([]) + + if print_losses: + clustering_loss = logging_ops.Print( + clustering_loss, + ['clustering_loss: ', clustering_loss, array_ops.shape( + clustering_loss)]) + + # Clustering specific summary. + summary.scalar('losses/score_pred', score_pred) + summary.scalar('losses/' + margin_type, clustering_score_pred) + summary.scalar('losses/score_gt', score_gt) + + return clustering_loss diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec539ab42b4e0ba90a2a1f379a1d4d4b49d11f3 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py @@ -0,0 +1,562 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for triplet_semihard_loss.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.losses.python import metric_learning as metric_loss_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test +try: + # pylint: disable=g-import-not-at-top + from sklearn import datasets + from sklearn import metrics + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + + +def pairwise_distance_np(feature, squared=False): + """Computes the pairwise distance matrix in numpy. + + Args: + feature: 2-D numpy array of size [number of data, feature dimension] + squared: Boolean. If true, output is the pairwise squared euclidean + distance matrix; else, output is the pairwise euclidean distance matrix. + + Returns: + pairwise_distances: 2-D numpy array of size + [number of data, number of data]. + """ + triu = np.triu_indices(feature.shape[0], 1) + upper_tri_pdists = np.linalg.norm(feature[triu[1]] - feature[triu[0]], axis=1) + if squared: + upper_tri_pdists **= 2. + num_data = feature.shape[0] + pairwise_distances = np.zeros((num_data, num_data)) + pairwise_distances[np.triu_indices(num_data, 1)] = upper_tri_pdists + # Make symmetrical. + pairwise_distances = pairwise_distances + pairwise_distances.T - np.diag( + pairwise_distances.diagonal()) + return pairwise_distances + + +class ContrastiveLossTest(test.TestCase): + + def testContrastive(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + labels = np.random.randint(0, 2, size=(num_data,)).astype(np.float32) + + # Compute the loss in NP + dist = np.sqrt( + np.sum(np.square(embeddings_anchor - embeddings_positive), axis=1)) + loss_np = np.mean( + labels * np.square(dist) + + (1.0 - labels) * np.square(np.maximum(margin - dist, 0.0))) + # Compute the loss with TF + loss_tf = metric_loss_ops.contrastive_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class TripletSemiHardLossTest(test.TestCase): + + def testTripletSemiHard(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + # Compute the loss in NP. + adjacency = np.equal(labels_reshaped, labels_reshaped.T) + + pdist_matrix = pairwise_distance_np(embedding, squared=True) + loss_np = 0.0 + num_positives = 0.0 + for i in range(num_data): + for j in range(num_data): + if adjacency[i][j] > 0.0 and i != j: + num_positives += 1.0 + + pos_distance = pdist_matrix[i][j] + neg_distances = [] + + for k in range(num_data): + if adjacency[i][k] == 0: + neg_distances.append(pdist_matrix[i][k]) + + # Sort by distance. + neg_distances.sort() + chosen_neg_distance = neg_distances[0] + + for l in range(len(neg_distances)): + chosen_neg_distance = neg_distances[l] + if chosen_neg_distance > pos_distance: + break + + loss_np += np.maximum( + 0.0, margin - chosen_neg_distance + pos_distance) + + loss_np /= num_positives + + # Compute the loss in TF. + loss_tf = metric_loss_ops.triplet_semihard_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embedding), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class LiftedStructLossTest(test.TestCase): + + def testLiftedStruct(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + + # Compute the loss in NP + adjacency = np.equal(labels_reshaped, labels_reshaped.T) + pdist_matrix = pairwise_distance_np(embedding) + loss_np = 0.0 + num_constraints = 0.0 + for i in range(num_data): + for j in range(num_data): + if adjacency[i][j] > 0.0 and i != j: + d_pos = pdist_matrix[i][j] + negs = [] + for k in range(num_data): + if not adjacency[i][k]: + negs.append(margin - pdist_matrix[i][k]) + for l in range(num_data): + if not adjacency[j][l]: + negs.append(margin - pdist_matrix[j][l]) + + negs = np.array(negs) + max_elem = np.max(negs) + negs -= max_elem + negs = np.exp(negs) + soft_maximum = np.log(np.sum(negs)) + max_elem + + num_constraints += 1.0 + this_loss = max(soft_maximum + d_pos, 0) + loss_np += this_loss * this_loss + + loss_np = loss_np / num_constraints / 2.0 + + # Compute the loss in TF + loss_tf = metric_loss_ops.lifted_struct_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embedding), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +def convert_to_list_of_sparse_tensor(np_matrix): + list_of_sparse_tensors = [] + nrows, ncols = np_matrix.shape + for i in range(nrows): + sp_indices = [] + for j in range(ncols): + if np_matrix[i][j] == 1: + sp_indices.append([j]) + + num_non_zeros = len(sp_indices) + list_of_sparse_tensors.append(sparse_tensor.SparseTensor( + indices=np.array(sp_indices), + values=np.ones((num_non_zeros,)), + dense_shape=np.array([ncols,]))) + + return list_of_sparse_tensors + + +class NpairsLossTest(test.TestCase): + + def testNpairs(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + num_classes = 5 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + + # Compute the loss in NP + reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1)) + reg_term += np.mean(np.sum(np.square(embeddings_positive), 1)) + reg_term *= 0.25 * reg_lambda + + similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T) + + labels_remapped = np.equal( + labels_reshaped, labels_reshaped.T).astype(np.float32) + labels_remapped /= np.sum(labels_remapped, axis=1, keepdims=True) + + xent_loss = math_ops.reduce_mean(nn.softmax_cross_entropy_with_logits( + logits=ops.convert_to_tensor(similarity_matrix), + labels=ops.convert_to_tensor(labels_remapped))).eval() + loss_np = xent_loss + reg_term + + # Compute the loss in TF + loss_tf = metric_loss_ops.npairs_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class NpairsLossMultiLabelTest(test.TestCase): + + def testNpairsMultiLabelLossWithSingleLabelEqualsNpairsLoss(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + labels = np.arange(num_data) + labels = np.reshape(labels, -1) + + # Compute vanila npairs loss. + loss_npairs = metric_loss_ops.npairs_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda).eval() + + # Compute npairs multilabel loss. + labels_one_hot = np.identity(num_data) + loss_npairs_multilabel = metric_loss_ops.npairs_loss_multilabel( + sparse_labels=convert_to_list_of_sparse_tensor(labels_one_hot), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda).eval() + + self.assertAllClose(loss_npairs, loss_npairs_multilabel) + + def testNpairsMultiLabel(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + num_classes = 10 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + + labels = np.random.randint(0, 2, (num_data, num_classes)) + # set entire column to one so that each row has at least one bit set. + labels[:, -1] = 1 + + # Compute the loss in NP + reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1)) + reg_term += np.mean(np.sum(np.square(embeddings_positive), 1)) + reg_term *= 0.25 * reg_lambda + + similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T) + + labels_remapped = np.dot(labels, labels.T).astype(np.float) + labels_remapped /= np.sum(labels_remapped, 1, keepdims=True) + + xent_loss = math_ops.reduce_mean(nn.softmax_cross_entropy_with_logits( + logits=ops.convert_to_tensor(similarity_matrix), + labels=ops.convert_to_tensor(labels_remapped))).eval() + loss_np = xent_loss + reg_term + + # Compute the loss in TF + loss_tf = metric_loss_ops.npairs_loss_multilabel( + sparse_labels=convert_to_list_of_sparse_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda) + loss_tf = loss_tf.eval() + + self.assertAllClose(loss_np, loss_tf) + + +def compute_ground_truth_cluster_score(feat, y): + y_unique = np.unique(y) + score_gt_np = 0.0 + for c in y_unique: + feat_subset = feat[y == c, :] + pdist_subset = pairwise_distance_np(feat_subset) + score_gt_np += -1.0 * np.min(np.sum(pdist_subset, axis=0)) + score_gt_np = score_gt_np.astype(np.float32) + return score_gt_np + + +def compute_cluster_loss_numpy(feat, + y, + margin_multiplier=1.0, + enable_pam_finetuning=True): + if enable_pam_finetuning: + facility = ForwardGreedyFacility( + n_clusters=np.unique(y).size).pam_augmented_fit(feat, y, + margin_multiplier) + else: + facility = ForwardGreedyFacility( + n_clusters=np.unique(y).size).loss_augmented_fit(feat, y, + margin_multiplier) + + score_augmented = facility.score_aug_ + score_gt = compute_ground_truth_cluster_score(feat, y) + return np.maximum(np.float32(0.0), score_augmented - score_gt) + + +class ForwardGreedyFacility(object): + + def __init__(self, n_clusters=8): + self.n_clusters = n_clusters + self.center_ics_ = None + + def _check_init_args(self): + # Check n_clusters. + if (self.n_clusters is None or self.n_clusters <= 0 or + not isinstance(self.n_clusters, int)): + raise ValueError('n_clusters has to be nonnegative integer.') + + def loss_augmented_fit(self, feat, y, loss_mult): + """Fit K-Medoids to the provided data.""" + self._check_init_args() + # Check that the array is good and attempt to convert it to + # Numpy array if possible. + feat = self._check_array(feat) + # Apply distance metric to get the distance matrix. + pdists = pairwise_distance_np(feat) + + num_data = feat.shape[0] + candidate_ids = list(range(num_data)) + candidate_scores = np.zeros(num_data,) + subset = [] + + k = 0 + while k < self.n_clusters: + candidate_scores = [] + for i in candidate_ids: + # push i to subset. + subset.append(i) + marginal_cost = -1.0 * np.sum(np.min(pdists[:, subset], axis=1)) + loss = 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset)) + candidate_scores.append(marginal_cost + loss_mult * loss) + # remove i from subset. + subset.pop() + + # push i_star to subset. + i_star = candidate_ids[np.argmax(candidate_scores)] + subset.append(i_star) + # remove i_star from candidate indices. + candidate_ids.remove(i_star) + k += 1 + + # Expose labels_ which are the assignments of + # the training data to clusters. + self.labels_ = self._get_cluster_ics(pdists, subset) + # Expose cluster centers, i.e. medoids. + self.cluster_centers_ = feat.take(subset, axis=0) + # Expose indices of chosen cluster centers. + self.center_ics_ = subset + # Expose the score = -\sum_{i \in V} min_{j \in S} || x_i - x_j || + self.score_ = np.float32(-1.0) * self._get_facility_distance(pdists, subset) + self.score_aug_ = self.score_ + loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset))) + self.score_aug_ = self.score_aug_.astype(np.float32) + # Expose the chosen cluster indices. + self.subset_ = subset + return self + + def _augmented_update_medoid_ics_in_place(self, pdists, y_gt, cluster_ics, + medoid_ics, loss_mult): + for cluster_idx in range(self.n_clusters): + # y_pred = self._get_cluster_ics(D, medoid_ics) + # Don't prematurely do the assignment step. + # Do this after we've updated all cluster medoids. + y_pred = cluster_ics + + if sum(y_pred == cluster_idx) == 0: + # Cluster is empty. + continue + + curr_score = ( + -1.0 * np.sum( + pdists[medoid_ics[cluster_idx], y_pred == cluster_idx]) + + loss_mult * (1.0 - metrics.normalized_mutual_info_score( + y_gt, y_pred))) + + pdist_in = pdists[y_pred == cluster_idx, :] + pdist_in = pdist_in[:, y_pred == cluster_idx] + + all_scores_fac = np.sum(-1.0 * pdist_in, axis=1) + all_scores_loss = [] + for i in range(y_pred.size): + if y_pred[i] != cluster_idx: + continue + # remove this cluster's current centroid + medoid_ics_i = medoid_ics[:cluster_idx] + medoid_ics[cluster_idx + 1:] + # add this new candidate to the centroid list + medoid_ics_i += [i] + y_pred_i = self._get_cluster_ics(pdists, medoid_ics_i) + all_scores_loss.append(loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score(y_gt, y_pred_i))) + + all_scores = all_scores_fac + all_scores_loss + max_score_idx = np.argmax(all_scores) + max_score = all_scores[max_score_idx] + + if max_score > curr_score: + medoid_ics[cluster_idx] = np.where( + y_pred == cluster_idx)[0][max_score_idx] + + def pam_augmented_fit(self, feat, y, loss_mult): + pam_max_iter = 5 + self._check_init_args() + feat = self._check_array(feat) + pdists = pairwise_distance_np(feat) + self.loss_augmented_fit(feat, y, loss_mult) + print('PAM -1 (before PAM): score: %f, score_aug: %f' % ( + self.score_, self.score_aug_)) + # Initialize from loss augmented facility location + subset = self.center_ics_ + for iter_ in range(pam_max_iter): + # update the cluster assignment + cluster_ics = self._get_cluster_ics(pdists, subset) + # update the medoid for each clusters + self._augmented_update_medoid_ics_in_place(pdists, y, cluster_ics, subset, + loss_mult) + self.score_ = np.float32(-1.0) * self._get_facility_distance( + pdists, subset) + self.score_aug_ = self.score_ + loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset))) + self.score_aug_ = self.score_aug_.astype(np.float32) + print('PAM iter: %d, score: %f, score_aug: %f' % (iter_, self.score_, + self.score_aug_)) + + self.center_ics_ = subset + self.labels_ = cluster_ics + return self + + def _check_array(self, feat): + # Check that the number of clusters is less than or equal to + # the number of samples + if self.n_clusters > feat.shape[0]: + raise ValueError('The number of medoids ' + '({}) '.format( + self.n_clusters) + 'must be larger than the number ' + + 'of samples ({})'.format(feat.shape[0])) + return feat + + def _get_cluster_ics(self, pdists, subset): + """Returns cluster indices for pdist and current medoid indices.""" + # Assign data points to clusters based on + # which cluster assignment yields + # the smallest distance` + cluster_ics = np.argmin(pdists[subset, :], axis=0) + return cluster_ics + + def _get_facility_distance(self, pdists, subset): + return np.sum(np.min(pdists[subset, :], axis=0)) + + +class ClusterLossTest(test.TestCase): + + def _genClusters(self, n_samples, n_clusters): + blobs = datasets.make_blobs( + n_samples=n_samples, centers=n_clusters) + embedding, labels = blobs + embedding = (embedding - embedding.mean(axis=0)) / embedding.std(axis=0) + embedding = embedding.astype(np.float32) + return embedding, labels + + def testClusteringLossPAMOff(self): + if not HAS_SKLEARN: + return + with self.test_session(): + margin_multiplier = 10.0 + embeddings, labels = self._genClusters(n_samples=128, n_clusters=64) + + loss_np = compute_cluster_loss_numpy( + embeddings, labels, margin_multiplier, enable_pam_finetuning=False) + loss_tf = metric_loss_ops.cluster_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embeddings), + margin_multiplier=margin_multiplier, + enable_pam_finetuning=False) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + def testClusteringLossPAMOn(self): + if not HAS_SKLEARN: + return + with self.test_session(): + margin_multiplier = 10.0 + embeddings, labels = self._genClusters(n_samples=128, n_clusters=64) + + loss_np = compute_cluster_loss_numpy( + embeddings, labels, margin_multiplier, enable_pam_finetuning=True) + loss_tf = metric_loss_ops.cluster_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embeddings), + margin_multiplier=margin_multiplier, + enable_pam_finetuning=True) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index e0cfab0b26d8f106e83f6223d057c9ef5f395f4f..be7c790ee9e11ca90c0756011003a919f7d930f8 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -484,6 +484,7 @@ $(wildcard tensorflow/core/*/*/*main.cc) \ $(wildcard tensorflow/core/debug/*.cc) \ $(wildcard tensorflow/core/framework/op_gen_lib.cc) \ $(wildcard tensorflow/core/graph/dot.*) \ +$(wildcard tensorflow/core/lib/db/*) \ $(wildcard tensorflow/core/lib/gif/*) \ $(wildcard tensorflow/core/lib/io/zlib*) \ $(wildcard tensorflow/core/lib/io/record*) \ diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index ff298e84ad13c720c1d5b989d437fb9402257b85..1fda907074545d9b78a902182e4cec9e4212c22d 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -142,6 +142,7 @@ tensorflow/core/kernels/cwise_op_sqrt.cc tensorflow/core/kernels/cwise_op_sigmoid.cc tensorflow/core/kernels/cwise_op_sign.cc tensorflow/core/kernels/cwise_op_select.cc +tensorflow/core/kernels/cwise_op_round.cc tensorflow/core/kernels/cwise_op_rsqrt.cc tensorflow/core/kernels/cwise_op_reciprocal.cc tensorflow/core/kernels/cwise_op_neg.cc @@ -160,6 +161,7 @@ tensorflow/core/kernels/cwise_op_invert.cc tensorflow/core/kernels/cwise_op_greater_equal.cc tensorflow/core/kernels/cwise_op_greater.cc tensorflow/core/kernels/cwise_op_floor_div.cc +tensorflow/core/kernels/cwise_op_floor_mod.cc tensorflow/core/kernels/cwise_op_floor.cc tensorflow/core/kernels/cwise_op_exp.cc tensorflow/core/kernels/cwise_op_equal_to_2.cc diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index dd479147749e286caa34aa658a8a51f865bcee20..7e2e96e160167ae68d3bdabacbbbeb45df61778f 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -54,15 +54,13 @@ class BytesInUseOp : public MemoryStatsOp { }; // Register this op on GPU only, see comment for MaxBytesInUse for reason -REGISTER_KERNEL_BUILDER( - Name("BytesInUse").Device(DEVICE_GPU).HostMemory("out"), - BytesInUseOp); +REGISTER_KERNEL_BUILDER(Name("BytesInUse").Device(DEVICE_GPU).HostMemory("out"), + BytesInUseOp); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER( - Name("BytesInUse").Device(DEVICE_SYCL).HostMemory("out"), - MaxBytesInUseOp); -#endif // TENSORFLOW_USE_SYCL + Name("BytesInUse").Device(DEVICE_SYCL).HostMemory("out"), MaxBytesInUseOp); +#endif // TENSORFLOW_USE_SYCL // Op that measures the total memory (in bytes) of a device. class BytesLimitOp : public MemoryStatsOp { diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py index 303c02dfa409bb7410233f0005d9f3cb0b5bc11e..2932ae1c8df32cd936cff932b061571c513fda79 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py @@ -749,7 +749,7 @@ def meta_graph_transform( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) - # Append newly added initalizers to collection. + # Append newly added initializers to collection. _add_new_inits_to_collection(meta_graph_def, updated_initializer_names) # Copy signature_defs, excluding any pruned nodes diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index a9bce65e55f36da0b930fcff619e9ed8841ef4c2..2c48882d0ea70bfdfa85730a2701c19cf76cb6e5 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -22,6 +22,10 @@ See the @{$python/contrib.metrics} guide. @@streaming_recall_at_thresholds @@streaming_precision @@streaming_precision_at_thresholds +@@streaming_false_positive_rate +@@streaming_false_positive_rate_at_thresholds +@@streaming_false_negative_rate +@@streaming_false_negative_rate_at_thresholds @@streaming_auc @@streaming_curve_points @@streaming_recall_at_k @@ -80,8 +84,12 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 76986d0156dada75abcc559d9db6b9addf26cccc..85c8e9038ac5642d0dbb20aea968474e0d7aa5f4 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -565,6 +565,213 @@ def streaming_recall(predictions, labels, weights=None, updates_collections=updates_collections, name=name) +def _true_negatives(labels, predictions, weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Sum the weights of true negatives. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric + value variable should be added to. + updates_collections: An optional list of collections that the metric update + ops should be added to. + name: An optional variable_scope name. + + Returns: + value_tensor: A `Tensor` representing the current value of the metric. + update_op: An operation that accumulates the error from a batch of data. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope( + name, 'true_negatives', (predictions, labels, weights)): + + predictions, labels, weights = _remove_squeezable_dimensions( + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), + math_ops.equal(predictions, False)) + return _count_condition(is_true_negative, weights, metrics_collections, + updates_collections) + + +def streaming_false_positive_rate(predictions, labels, weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the false positive rate of predictions with respect to labels. + + The `false_positive_rate` function creates two local variables, + `false_positives` and `true_negatives`, that are used to compute the + false positive rate. This value is ultimately returned as + `false_positive_rate`, an idempotent operation that simply divides + `false_positives` by the sum of `false_positives` and `true_negatives`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. `update_op` weights each prediction by the + corresponding value in `weights`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_positive_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_positive_rate: Scalar float `Tensor` with the value of + `false_positives` divided by the sum of `false_positives` and + `true_negatives`. + update_op: `Operation` that increments `false_positives` and + `true_negatives` variables appropriately and whose value matches + `false_positive_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope( + name, 'false_positive_rate', (predictions, labels, weights)): + predictions, labels, weights = _remove_squeezable_dimensions( + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + + false_p, false_positives_update_op = metrics.false_positives( + labels, predictions, weights, metrics_collections=None, + updates_collections=None, name=None) + true_n, true_negatives_update_op = _true_negatives( + labels, predictions, weights, metrics_collections=None, + updates_collections=None, name=None) + + def compute_fpr(fp, tn, name): + return array_ops.where( + math_ops.greater(fp + tn, 0), + math_ops.div(fp, fp + tn), + 0, + name) + + fpr = compute_fpr(false_p, true_n, 'value') + update_op = compute_fpr( + false_positives_update_op, true_negatives_update_op, 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fpr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fpr, update_op + + +def streaming_false_negative_rate(predictions, labels, weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the false negative rate of predictions with respect to labels. + + The `false_negative_rate` function creates two local variables, + `false_negatives` and `true_positives`, that are used to compute the + false positive rate. This value is ultimately returned as + `false_negative_rate`, an idempotent operation that simply divides + `false_negatives` by the sum of `false_negatives` and `true_positives`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_negative_rate`. `update_op` weights each prediction by the + corresponding value in `weights`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_negative_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_negative_rate: Scalar float `Tensor` with the value of + `false_negatives` divided by the sum of `false_negatives` and + `true_positives`. + update_op: `Operation` that increments `false_negatives` and + `true_positives` variables appropriately and whose value matches + `false_negative_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope( + name, 'false_negative_rate', (predictions, labels, weights)): + predictions, labels, weights = _remove_squeezable_dimensions( + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + + false_n, false_negatives_update_op = metrics.false_negatives( + labels, predictions, weights, metrics_collections=None, + updates_collections=None, name=None) + true_p, true_positives_update_op = metrics.true_positives( + labels, predictions, weights, metrics_collections=None, + updates_collections=None, name=None) + + def compute_fnr(fn, tp, name): + return array_ops.where( + math_ops.greater(fn + tp, 0), + math_ops.div(fn, fn + tp), + 0, + name) + + fnr = compute_fnr(false_n, true_p, 'value') + update_op = compute_fnr( + false_negatives_update_op, true_positives_update_op, 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fnr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fnr, update_op + + def _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=None, includes=None): """Computes true_positives, false_negatives, true_negatives, false_positives. @@ -1114,6 +1321,142 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds, updates_collections=updates_collections, name=name) +def streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds, weights=None, metrics_collections=None, + updates_collections=None, name=None): + """Computes various fpr values for different `thresholds` on `predictions`. + + The `streaming_false_positive_rate_at_thresholds` function creates two + local variables, `false_positives`, `true_negatives`, for various values of + thresholds. `false_positive_rate[i]` is defined as the total weight + of values in `predictions` above `thresholds[i]` whose corresponding entry in + `labels` is `False`, divided by the total weight of `False` values in `labels` + (`false_positives[i] / (false_positives[i] + true_negatives[i])`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A `bool` `Tensor` whose shape matches `predictions`. + thresholds: A python list or tuple of float thresholds in `[0, 1]`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and + must be broadcastable to `labels` (i.e., all dimensions must be either + `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_positive_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`. + update_op: An operation that increments the `false_positives` and + `true_negatives` variables that are used in the computation of + `false_positive_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope( + name, 'false_positive_rate_at_thresholds', + (predictions, labels, weights)): + values, update_ops = _streaming_confusion_matrix_at_thresholds( + predictions, labels, thresholds, weights, includes=('fp', 'tn')) + + # Avoid division by zero. + epsilon = 1e-7 + def compute_fpr(fp, tn, name): + return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) + + fpr = compute_fpr(values['fp'], values['tn'], 'value') + update_op = compute_fpr( + update_ops['fp'], update_ops['tn'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fpr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fpr, update_op + + +def streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds, weights=None, metrics_collections=None, + updates_collections=None, name=None): + """Computes various fnr values for different `thresholds` on `predictions`. + + The `streaming_false_negative_rate_at_thresholds` function creates two + local variables, `false_negatives`, `true_positives`, for various values of + thresholds. `false_negative_rate[i]` is defined as the total weight + of values in `predictions` above `thresholds[i]` whose corresponding entry in + `labels` is `False`, divided by the total weight of `True` values in `labels` + (`false_negatives[i] / (false_negatives[i] + true_positives[i])`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A `bool` `Tensor` whose shape matches `predictions`. + thresholds: A python list or tuple of float thresholds in `[0, 1]`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and + must be broadcastable to `labels` (i.e., all dimensions must be either + `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_negative_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`. + update_op: An operation that increments the `false_negatives` and + `true_positives` variables that are used in the computation of + `false_negative_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope( + name, 'false_negative_rate_at_thresholds', + (predictions, labels, weights)): + values, update_ops = _streaming_confusion_matrix_at_thresholds( + predictions, labels, thresholds, weights, includes=('fn', 'tp')) + + # Avoid division by zero. + epsilon = 1e-7 + def compute_fnr(fn, tp, name): + return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) + + fnr = compute_fnr(values['fn'], values['tp'], 'value') + update_op = compute_fnr( + update_ops['fn'], update_ops['tp'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fnr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fnr, update_op + + def _at_k_name(name, k=None, class_id=None): if k is not None: name = '%s_at_%d' % (name, k) @@ -2479,8 +2822,12 @@ __all__ = [ 'streaming_accuracy', 'streaming_auc', 'streaming_curve_points', + 'streaming_false_negative_rate', + 'streaming_false_negative_rate_at_thresholds', 'streaming_false_negatives', 'streaming_false_negatives_at_thresholds', + 'streaming_false_positive_rate', + 'streaming_false_positive_rate_at_thresholds', 'streaming_false_positives', 'streaming_false_positives_at_thresholds', 'streaming_mean', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 9b959b43a9db8baac5b37524e81bfbb11d6ad868..c5fcc20abd4927c5408071bae8fa8620cd4d7eb2 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1355,6 +1355,262 @@ class StreamingRecallTest(test.TestCase): self.assertEqual(0, recall.eval()) +class StreamingFPRTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_local_variables(self, ( + 'false_positive_rate/false_positives/count:0', + 'false_positive_rate/true_negatives/count:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. + initial_fpr = fpr.eval() + for _ in range(10): + self.assertEqual(initial_fpr, fpr.eval()) + + def testAllCorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(np_inputs) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fpr.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.5, update_op.eval()) + self.assertAlmostEqual(0.5, fpr.eval()) + + def testWeighted1d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[2], [5]]) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fp = 2.0 + 5.0 + weighted_f = (2.0 + 2.0) + (5.0 + 5.0) + expected_fpr = weighted_fp / weighted_f + self.assertAlmostEqual(expected_fpr, update_op.eval()) + self.assertAlmostEqual(expected_fpr, fpr.eval()) + + def testWeighted2d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fp = 1.0 + 3.0 + weighted_f = (1.0 + 4.0) + (2.0 + 3.0) + expected_fpr = weighted_fp / weighted_f + self.assertAlmostEqual(expected_fpr, update_op.eval()) + self.assertAlmostEqual(expected_fpr, fpr.eval()) + + def testAllIncorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(1 - np_inputs) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(1, fpr.eval()) + + def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self): + predictions = array_ops.ones((1, 4)) + labels = array_ops.ones((1, 4)) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fpr.eval()) + + +class StreamingFNRTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_local_variables(self, ( + 'false_negative_rate/false_negatives/count:0', + 'false_negative_rate/true_positives/count:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. + initial_fnr = fnr.eval() + for _ in range(10): + self.assertEqual(initial_fnr, fnr.eval()) + + def testAllCorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(np_inputs) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fnr.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.5, update_op.eval()) + self.assertAlmostEqual(0.5, fnr.eval()) + + def testWeighted1d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[2], [5]]) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fn = 2.0 + 5.0 + weighted_t = (2.0 + 2.0) + (5.0 + 5.0) + expected_fnr = weighted_fn / weighted_t + self.assertAlmostEqual(expected_fnr, update_op.eval()) + self.assertAlmostEqual(expected_fnr, fnr.eval()) + + def testWeighted2d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fn = 2.0 + 4.0 + weighted_t = (2.0 + 3.0) + (1.0 + 4.0) + expected_fnr = weighted_fn / weighted_t + self.assertAlmostEqual(expected_fnr, update_op.eval()) + self.assertAlmostEqual(expected_fnr, fnr.eval()) + + def testAllIncorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(1 - np_inputs) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(1, fnr.eval()) + + def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self): + predictions = array_ops.zeros((1, 4)) + labels = array_ops.zeros((1, 4)) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fnr.eval()) + + class StreamingCurvePointsTest(test.TestCase): def setUp(self): @@ -1900,116 +2156,639 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): for _ in range(10): sess.run(update_op) - # Then verify idempotency. - initial_sensitivity = sensitivity.eval() - for _ in range(10): - self.assertAlmostEqual(initial_sensitivity, sensitivity.eval(), 5) + # Then verify idempotency. + initial_sensitivity = sensitivity.eval() + for _ in range(10): + self.assertAlmostEqual(initial_sensitivity, sensitivity.eval(), 5) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + specificity, update_op = metrics.streaming_sensitivity_at_specificity( + predictions, labels, specificity=0.7) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertEqual(1, sess.run(update_op)) + self.assertEqual(1, specificity.eval()) + + def testSomeCorrectHighSpecificity(self): + predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9] + labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + specificity, update_op = metrics.streaming_sensitivity_at_specificity( + predictions, labels, specificity=0.8) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.8, sess.run(update_op)) + self.assertAlmostEqual(0.8, specificity.eval()) + + def testSomeCorrectLowSpecificity(self): + predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + specificity, update_op = metrics.streaming_sensitivity_at_specificity( + predictions, labels, specificity=0.4) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.6, sess.run(update_op)) + self.assertAlmostEqual(0.6, specificity.eval()) + + def testWeighted(self): + predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + weights = constant_op.constant(weights_values) + specificity, update_op = metrics.streaming_sensitivity_at_specificity( + predictions, labels, weights=weights, specificity=0.4) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.675, sess.run(update_op)) + self.assertAlmostEqual(0.675, specificity.eval()) + + +# TODO(nsilberman): Break this up into two sets of tests. +class StreamingPrecisionRecallThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_precision_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0]) + _assert_local_variables(self, ( + 'precision_at_thresholds/true_positives:0', + 'precision_at_thresholds/false_positives:0',)) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + prec, _ = metrics.streaming_precision_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + rec, _ = metrics.streaming_recall_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, precision_op = metrics.streaming_precision_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + _, recall_op = metrics.streaming_recall_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + self.assertListEqual( + ops.get_collection(my_collection_name), [precision_op, recall_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + thresholds = [0, 0.5, 1.0] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates, then verify idempotency. + sess.run([prec_op, rec_op]) + initial_prec = prec.eval() + initial_rec = rec.eval() + for _ in range(10): + sess.run([prec_op, rec_op]) + self.assertAllClose(initial_prec, prec.eval()) + self.assertAllClose(initial_rec, rec.eval()) + + # TODO(nsilberman): fix tests (passing but incorrect). + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertEqual(1, prec.eval()) + self.assertEqual(1, rec.eval()) + + def testSomeCorrect(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0.5, prec.eval()) + self.assertAlmostEqual(0.5, rec.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0, prec.eval()) + self.assertAlmostEqual(0, rec.eval()) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds, weights=weights) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds, weights=weights) + + [prec_low, prec_high] = array_ops.split( + value=prec, num_or_size_splits=2, axis=0) + prec_low = array_ops.reshape(prec_low, shape=()) + prec_high = array_ops.reshape(prec_high, shape=()) + [rec_low, rec_high] = array_ops.split( + value=rec, num_or_size_splits=2, axis=0) + rec_low = array_ops.reshape(rec_low, shape=()) + rec_high = array_ops.reshape(rec_high, shape=()) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(1.0, prec_low.eval(), places=5) + self.assertAlmostEqual(0.0, prec_high.eval(), places=5) + self.assertAlmostEqual(1.0, rec_low.eval(), places=5) + self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds, weights=weights) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds, weights=weights) + + [prec_low, prec_high] = array_ops.split( + value=prec, num_or_size_splits=2, axis=0) + prec_low = array_ops.reshape(prec_low, shape=()) + prec_high = array_ops.reshape(prec_high, shape=()) + [rec_low, rec_high] = array_ops.split( + value=rec, num_or_size_splits=2, axis=0) + rec_low = array_ops.reshape(rec_low, shape=()) + rec_high = array_ops.reshape(rec_high, shape=()) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(1.0, prec_low.eval(), places=5) + self.assertAlmostEqual(0.0, prec_high.eval(), places=5) + self.assertAlmostEqual(1.0, rec_low.eval(), places=5) + self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + + def testExtremeThresholds(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) + thresholds = [-1.0, 2.0] # lower/higher than any values + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + [prec_low, prec_high] = array_ops.split( + value=prec, num_or_size_splits=2, axis=0) + [rec_low, rec_high] = array_ops.split( + value=rec, num_or_size_splits=2, axis=0) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0.75, prec_low.eval()) + self.assertAlmostEqual(0.0, prec_high.eval()) + self.assertAlmostEqual(1.0, rec_low.eval()) + self.assertAlmostEqual(0.0, rec_high.eval()) + + def testZeroLabelsPredictions(self): + with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) + labels = array_ops.zeros([4]) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0, prec.eval(), 6) + self.assertAlmostEqual(0, rec.eval(), 6) + + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) + + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [0.3] + + tp = 0 + fp = 0 + fn = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] > thresholds[0]: + if labels[i] == 1: + tp += 1 + else: + fp += 1 + else: + if labels[i] == 1: + fn += 1 + else: + tn += 1 + epsilon = 1e-7 + expected_prec = tp / (epsilon + tp + fp) + expected_rec = tp / (epsilon + tp + fn) + + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) + + with self.test_session() as sess: + # Reshape the data so its easy to queue up: + predictions_batches = predictions.reshape((batch_size, num_batches)) + labels_batches = labels.reshape((batch_size, num_batches)) + + # Enqueue the data: + predictions_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + labels_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + + for i in range(int(num_batches)): + tf_prediction = constant_op.constant(predictions_batches[:, i]) + tf_label = constant_op.constant(labels_batches[:, i]) + sess.run([ + predictions_queue.enqueue(tf_prediction), + labels_queue.enqueue(tf_label) + ]) + + tf_predictions = predictions_queue.dequeue() + tf_labels = labels_queue.dequeue() + + prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions, + tf_labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions, + tf_labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + for _ in range(int(num_samples / batch_size)): + sess.run([prec_op, rec_op]) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_prec, prec.eval(), 2) + self.assertAlmostEqual(expected_rec, rec.eval(), 2) + + +class StreamingFPRThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0]) + _assert_local_variables(self, ( + 'false_positive_rate_at_thresholds/false_positives:0', + 'false_positive_rate_at_thresholds/true_negatives:0',)) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + fpr, _ = metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [fpr]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + self.assertListEqual( + ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + thresholds = [0, 0.5, 1.0] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(fpr_op) + + # Then verify idempotency. + initial_fpr = fpr.eval() + for _ in range(10): + self.assertAllClose(initial_fpr, fpr.eval()) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertEqual(0, fpr.eval()) + + def testSomeCorrect(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.5, fpr.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(1, fpr.eval()) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds, weights=weights) + + fpr_low = fpr[0] + fpr_high = fpr[1] + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds, weights=weights) + + fpr_low = fpr[0] + fpr_high = fpr[1] - def testAllCorrect(self): - inputs = np.random.randint(0, 2, size=(100, 1)) + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) - predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) - labels = constant_op.constant(inputs) - specificity, update_op = metrics.streaming_sensitivity_at_specificity( - predictions, labels, specificity=0.7) + self.assertAlmostEqual(0.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + def testExtremeThresholds(self): with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - self.assertEqual(1, sess.run(update_op)) - self.assertEqual(1, specificity.eval()) + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) + thresholds = [-1.0, 2.0] # lower/higher than any values + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) - def testSomeCorrectHighSpecificity(self): - predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9] - labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + fpr_low = fpr[0] + fpr_high = fpr[1] - predictions = constant_op.constant( - predictions_values, dtype=dtypes_lib.float32) - labels = constant_op.constant(labels_values) - specificity, update_op = metrics.streaming_sensitivity_at_specificity( - predictions, labels, specificity=0.8) + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + self.assertAlmostEqual(1.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testZeroLabelsPredictions(self): with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) + labels = array_ops.zeros([4]) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.8, sess.run(update_op)) - self.assertAlmostEqual(0.8, specificity.eval()) + sess.run(fpr_op) - def testSomeCorrectLowSpecificity(self): - predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] - labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + self.assertAlmostEqual(0, fpr.eval(), 6) - predictions = constant_op.constant( - predictions_values, dtype=dtypes_lib.float32) - labels = constant_op.constant(labels_values) - specificity, update_op = metrics.streaming_sensitivity_at_specificity( - predictions, labels, specificity=0.4) + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.6, sess.run(update_op)) - self.assertAlmostEqual(0.6, specificity.eval()) + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [0.3] - def testWeighted(self): - predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] - labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] - weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + fp = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] > thresholds[0]: + if labels[i] == 0: + fp += 1 + else: + if labels[i] == 0: + tn += 1 + epsilon = 1e-7 + expected_fpr = fp / (epsilon + fp + tn) - predictions = constant_op.constant( - predictions_values, dtype=dtypes_lib.float32) - labels = constant_op.constant(labels_values) - weights = constant_op.constant(weights_values) - specificity, update_op = metrics.streaming_sensitivity_at_specificity( - predictions, labels, weights=weights, specificity=0.4) + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) with self.test_session() as sess: + # Reshape the data so its easy to queue up: + predictions_batches = predictions.reshape((batch_size, num_batches)) + labels_batches = labels.reshape((batch_size, num_batches)) + + # Enqueue the data: + predictions_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + labels_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + + for i in range(int(num_batches)): + tf_prediction = constant_op.constant(predictions_batches[:, i]) + tf_label = constant_op.constant(labels_batches[:, i]) + sess.run([ + predictions_queue.enqueue(tf_prediction), + labels_queue.enqueue(tf_label) + ]) + + tf_predictions = predictions_queue.dequeue() + tf_labels = labels_queue.dequeue() + + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + tf_predictions, tf_labels, thresholds) + sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.675, sess.run(update_op)) - self.assertAlmostEqual(0.675, specificity.eval()) + for _ in range(int(num_samples / batch_size)): + sess.run(fpr_op) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_fpr, fpr.eval(), 2) -# TODO(nsilberman): Break this up into two sets of tests. -class StreamingPrecisionRecallThresholdsTest(test.TestCase): +class StreamingFNRThresholdsTest(test.TestCase): def setUp(self): np.random.seed(1) ops.reset_default_graph() def testVars(self): - metrics.streaming_precision_at_thresholds( + metrics.streaming_false_negative_rate_at_thresholds( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0]) _assert_local_variables(self, ( - 'precision_at_thresholds/true_positives:0', - 'precision_at_thresholds/false_positives:0',)) + 'false_negative_rate_at_thresholds/false_negatives:0', + 'false_negative_rate_at_thresholds/true_positives:0',)) def testMetricsCollection(self): my_collection_name = '__metrics__' - prec, _ = metrics.streaming_precision_at_thresholds( - predictions=array_ops.ones((10, 1)), - labels=array_ops.ones((10, 1)), - thresholds=[0, 0.5, 1.0], - metrics_collections=[my_collection_name]) - rec, _ = metrics.streaming_recall_at_thresholds( + fnr, _ = metrics.streaming_false_negative_rate_at_thresholds( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], metrics_collections=[my_collection_name]) - self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec]) + self.assertListEqual(ops.get_collection(my_collection_name), [fnr]) def testUpdatesCollection(self): my_collection_name = '__updates__' - _, precision_op = metrics.streaming_precision_at_thresholds( - predictions=array_ops.ones((10, 1)), - labels=array_ops.ones((10, 1)), - thresholds=[0, 0.5, 1.0], - updates_collections=[my_collection_name]) - _, recall_op = metrics.streaming_recall_at_thresholds( + _, update_op = metrics.streaming_false_negative_rate_at_thresholds( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], updates_collections=[my_collection_name]) self.assertListEqual( - ops.get_collection(my_collection_name), [precision_op, recall_op]) + ops.get_collection(my_collection_name), [update_op]) def testValueTensorIsIdempotent(self): predictions = random_ops.random_uniform( @@ -2017,25 +2796,21 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): labels = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) - # Run several updates, then verify idempotency. - sess.run([prec_op, rec_op]) - initial_prec = prec.eval() - initial_rec = rec.eval() + # Run several updates. for _ in range(10): - sess.run([prec_op, rec_op]) - self.assertAllClose(initial_prec, prec.eval()) - self.assertAllClose(initial_rec, rec.eval()) + sess.run(fnr_op) + + # Then verify idempotency. + initial_fnr = fnr.eval() + for _ in range(10): + self.assertAllClose(initial_fnr, fnr.eval()) - # TODO(nsilberman): fix tests (passing but incorrect). def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -2043,17 +2818,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertEqual(1, prec.eval()) - self.assertEqual(1, rec.eval()) + self.assertEqual(0, fnr.eval()) def testSomeCorrect(self): with self.test_session() as sess: @@ -2061,17 +2832,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0.5, prec.eval()) - self.assertAlmostEqual(0.5, rec.eval()) + self.assertAlmostEqual(0.5, fnr.eval()) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -2080,17 +2847,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0, prec.eval()) - self.assertAlmostEqual(0, rec.eval()) + self.assertAlmostEqual(1, fnr.eval()) def testWeights1d(self): with self.test_session() as sess: @@ -2100,27 +2863,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): weights = constant_op.constant( [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) thresholds = [0.5, 1.1] - prec, prec_op = metrics.streaming_precision_at_thresholds( - predictions, labels, thresholds, weights=weights) - rec, rec_op = metrics.streaming_recall_at_thresholds( + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(1.0, prec_low.eval(), places=5) - self.assertAlmostEqual(0.0, prec_high.eval(), places=5) - self.assertAlmostEqual(1.0, rec_low.eval(), places=5) - self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + self.assertAlmostEqual(0.0, fnr_low.eval(), places=5) + self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testWeights2d(self): with self.test_session() as sess: @@ -2130,27 +2883,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): weights = constant_op.constant( [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) thresholds = [0.5, 1.1] - prec, prec_op = metrics.streaming_precision_at_thresholds( - predictions, labels, thresholds, weights=weights) - rec, rec_op = metrics.streaming_recall_at_thresholds( + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(1.0, prec_low.eval(), places=5) - self.assertAlmostEqual(0.0, prec_high.eval(), places=5) - self.assertAlmostEqual(1.0, rec_low.eval(), places=5) - self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + self.assertAlmostEqual(0.0, fnr_low.eval(), places=5) + self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testExtremeThresholds(self): with self.test_session() as sess: @@ -2158,41 +2901,30 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) thresholds = [-1.0, 2.0] # lower/higher than any values - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0.75, prec_low.eval()) - self.assertAlmostEqual(0.0, prec_high.eval()) - self.assertAlmostEqual(1.0, rec_low.eval()) - self.assertAlmostEqual(0.0, rec_high.eval()) + self.assertAlmostEqual(0.0, fnr_low.eval()) + self.assertAlmostEqual(1.0, fnr_high.eval()) def testZeroLabelsPredictions(self): with self.test_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0, prec.eval(), 6) - self.assertAlmostEqual(0, rec.eval(), 6) + self.assertAlmostEqual(0, fnr.eval(), 6) def testWithMultipleUpdates(self): num_samples = 1000 @@ -2207,24 +2939,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions[predictions < 0] = 0 thresholds = [0.3] - tp = 0 - fp = 0 fn = 0 - tn = 0 + tp = 0 for i in range(num_samples): if predictions[i] > thresholds[0]: if labels[i] == 1: tp += 1 - else: - fp += 1 else: if labels[i] == 1: fn += 1 - else: - tn += 1 epsilon = 1e-7 - expected_prec = tp / (epsilon + tp + fp) - expected_rec = tp / (epsilon + tp + fn) + expected_fnr = fn / (epsilon + fn + tp) labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) @@ -2251,21 +2976,16 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): tf_predictions = predictions_queue.dequeue() tf_labels = labels_queue.dequeue() - prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions, - tf_labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions, - tf_labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + tf_predictions, tf_labels, thresholds) sess.run(variables.local_variables_initializer()) for _ in range(int(num_samples / batch_size)): - sess.run([prec_op, rec_op]) + sess.run(fnr_op) # Since this is only approximate, we can't expect a 6 digits match. # Although with higher number of samples/thresholds we should see the # accuracy improving - self.assertAlmostEqual(expected_prec, prec.eval(), 2) - self.assertAlmostEqual(expected_rec, rec.eval(), 2) + self.assertAlmostEqual(expected_fnr, fnr.eval(), 2) # TODO(ptucker): Remove when we remove `streaming_recall_at_k`. @@ -4978,7 +5698,7 @@ class StreamingMeanIOUTest(test.TestCase): sess.run(variables.local_variables_initializer()) for _ in range(5): sess.run(update_op) - desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0, 0.]) + desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0]) self.assertAlmostEqual(desired_output, miou.eval()) def testUpdateOpEvalIsAccumulatedConfusionMatrix(self): @@ -5060,6 +5780,58 @@ class StreamingMeanIOUTest(test.TestCase): desired_miou = np.mean([2. / 4., 4. / 6.]) self.assertAlmostEqual(desired_miou, miou.eval()) + def testMissingClassInLabels(self): + labels = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant([ + [[0, 0, 2, 1, 1, 0], + [0, 1, 2, 2, 0, 1]], + [[0, 0, 2, 1, 1, 1], + [1, 1, 2, 0, 0, 0]]]) + num_classes = 3 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval()) + self.assertAlmostEqual( + 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)), + miou.eval()) + + def testMissingClassOverallSmall(self): + labels = constant_op.constant([0]) + predictions = constant_op.constant([0]) + num_classes = 2 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[1, 0], [0, 0]], update_op.eval()) + self.assertAlmostEqual(1, miou.eval()) + + def testMissingClassOverallLarge(self): + labels = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 1, 0, 0, 1, 1]], + [[0, 0, 0, 1, 1, 1], + [1, 1, 1, 0, 0, 0]]]) + num_classes = 3 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval()) + self.assertAlmostEqual( + 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval()) + class StreamingConcatTest(test.TestCase): diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py index b94f7b0a353c4c3c698a927d8718bb5b490872cb..9ed16a6f078a506b60fd14f4356ff65a0a692203 100644 --- a/tensorflow/contrib/mpi_collectives/__init__.py +++ b/tensorflow/contrib/mpi_collectives/__init__.py @@ -194,7 +194,7 @@ class DistributedOptimizer(tf.train.Optimizer): See Optimizer.compute_gradients() for more info. - In DistributedOptimizer, compute_gradients() is overriden to also + In DistributedOptimizer, compute_gradients() is overridden to also allreduce the gradients before returning them. """ gradients = (super(DistributedOptimizer, self) diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index d6508362b8bf01468a43b26d6a0d0c9807b5967e..5e7263ff625c0570ab31931ba8ee886b38142206 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -71,10 +71,12 @@ tf_kernel_library( "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", + "kernels/nccl_rewrite.cc", ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:proto_text", "@nccl_archive//:nccl", ], alwayslink = 1, diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc index 4eb52492dbcc386941029709631314634c1c9be1..266d4f6f0de0274dca2bfc9022503f09b0ca7d42 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc @@ -15,8 +15,6 @@ limitations under the License. #if GOOGLE_CUDA -#include -#include #include #include "src/nccl.h" @@ -24,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace { // Base class for all communicator ops that use nccl. // @@ -134,7 +133,7 @@ class NcclReduceSendKernel : public NcclReduceOpBase { compute_stream, &c->input(0), std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceSend").Device(DEVICE_GPU), NcclReduceSendKernel); // To execute a single reduce, this kernel is called once for one devices, and @@ -166,7 +165,7 @@ class NcclReduceRecvKernel : public NcclReduceOpBase { private: ncclRedOp_t reduction_op_; }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceRecv").Device(DEVICE_GPU), NcclReduceRecvKernel); // To execute a single broadcast, this kernel is called once for one device, and @@ -191,7 +190,7 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase { std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclBroadcastSend").Device(DEVICE_GPU), NcclBroadcastSendKernel); // To execute a single broadcast, this kernel is called once for all but one of @@ -206,7 +205,7 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { const Tensor& shape_t = c->input(0); TensorShape shape; OP_REQUIRES_OK_ASYNC( - c, TensorShapeUtils::MakeShape(shape_t.vec(), &shape), done); + c, TensorShapeUtils::MakeShape(shape_t.vec(), &shape), done); Tensor* out_t; OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done); @@ -224,9 +223,24 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { } }; REGISTER_KERNEL_BUILDER( - Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), + Name("_NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), NcclBroadcastRecvKernel); +// Define stub kernels for the ops that get replaced post placement. +class NcclStubKernel : public AsyncOpKernel { + public: + explicit NcclStubKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {} + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + c->SetStatus(errors::Unimplemented( + "This op should be replaced during graph optimization.")); + done(); + } +}; +REGISTER_KERNEL_BUILDER(Name("NcclBroadcast").Device(DEVICE_GPU), + NcclStubKernel); +REGISTER_KERNEL_BUILDER(Name("NcclReduce").Device(DEVICE_GPU), NcclStubKernel); + +} // namespace } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc new file mode 100644 index 0000000000000000000000000000000000000000..94a77c59da719c5f8e155cde0580d00a7e6bd3d5 --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc @@ -0,0 +1,271 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include +#include + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace { + +// Replaces NcclReduce node with _NcclReduceRecv reusing one input of same +// device, adds one _NcclReduceSend for each other input. +Status ReplaceReduce(Graph* graph, Node* node) { + string reduction; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "reduction", &reduction)); + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int num_devices = node->num_inputs(); + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("reduction", reduction) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + std::vector control_inputs; + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + control_inputs.push_back(edge->src()); + } + } + std::vector out_nodes; + for (const auto& edge : node->out_edges()) { + out_nodes.emplace_back(edge->dst(), edge->dst_input()); + } + int recv_dev = node->assigned_device_name_index(); + NodeBuilder recv_builder = + make_builder("_NcclReduceRecv", "Recv").ControlInputs(control_inputs); + bool recv_input_set = false; + int send_counter = 0; + for (const auto& edge : node->in_edges()) { + Node* src_node = edge->src(); + if (edge->IsControlEdge()) { + continue; + } + int send_dev = src_node->assigned_device_name_index(); + if (!recv_input_set && send_dev == recv_dev) { + recv_builder.Input(src_node); + recv_input_set = true; + continue; + } + auto send_builder = make_builder("_NcclReduceSend", + strings::StrCat("Send_", ++send_counter)) + .Input(src_node) + .ControlInputs(control_inputs); + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + // Send nodes don't have any outputs and therefore have no data dependencies + // to the outputs of the graph. We add a control dependency to the receive + // node so that those 'dangling' nodes are run. + // TODO(b/67027412): Avoid these cross-device control edges. + for (const auto& out_node : out_nodes) { + graph->AddControlEdge(send_node, out_node.node); + } + } + if (!recv_input_set) { + return errors::InvalidArgument( + "No input tensor uses the same device as the NcclReduce op"); + } + Node* recv_node = nullptr; + TF_RETURN_IF_ERROR(recv_builder.Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + graph->RemoveNode(node); + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(recv_node, out_node.node); + } else { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + return Status::OK(); +} + +TensorProto TensorFromShape(const TensorShapeProto& shape) { + TensorProto result; + result.set_dtype(DT_INT32); + for (const auto& dim : shape.dim()) { + result.add_int_val(dim.size()); + } + result.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size()); + return result; +} + +// Replaces NcclBroadcast node with _NcclBroadcastSend, connects the input to +// all outputs of same device, adds one _NcclBroadcastRecv for each other output +// device. +Status ReplaceBroadcast(Graph* graph, Node* node) { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int send_dev = node->assigned_device_name_index(); + int num_devices = 0; // Number of distinct devices, incremented below. + + // Map device name index to nodes that take the broadcast as input. + std::vector> out_nodes_map; + for (const auto& edge : node->out_edges()) { + int dst_dev = edge->IsControlEdge() + ? send_dev + : edge->dst()->assigned_device_name_index(); + if (out_nodes_map.size() <= dst_dev) { + out_nodes_map.resize(dst_dev + 1); + } + auto it = out_nodes_map.begin() + dst_dev; + if (it->empty()) { + ++num_devices; + } + it->emplace_front(NodeBuilder::NodeOut(edge->dst(), edge->dst_input())); + } + + if (num_devices <= 1) { + // Only one participating device, skip NCCL op. + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(node->input_edge(0, &in_edge)); + Node* in_node = in_edge->src(); + int in_index = in_edge->src_output(); + graph->RemoveNode(node); + for (const auto& out_nodes : out_nodes_map) { + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(in_node, out_node.node); + } else { + graph->AddEdge(in_node, in_index, out_node.node, out_node.index); + } + } + } + return Status::OK(); + } + + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + + // Create broadcast send node and replace the original broadcast node. + NodeBuilder::NodeOut in_node; + NodeBuilder send_builder = make_builder("_NcclBroadcastSend", "Send"); + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + send_builder.ControlInput(edge->src()); + } else { + in_node = NodeBuilder::NodeOut(edge->src(), edge->src_output()); + send_builder.Input(in_node); + } + } + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + + TensorShapeProto shape_proto; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "shape", &shape_proto)); + + // Delete the original node before reconnecting to outputs. + graph->RemoveNode(node); + + // Connect all outputs on the device of broadcast send. + for (const auto& out_node : out_nodes_map[send_dev]) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(send_node, out_node.node); + } else { + graph->AddEdge(in_node.node, in_node.index, out_node.node, + out_node.index); + // Add control edge so send node is run. + graph->AddControlEdge(send_node, out_node.node); + } + } + out_nodes_map[send_dev].clear(); + + TensorProto tensor_proto = TensorFromShape(shape_proto); + bool is_fully_defined = TensorShape(shape_proto).IsFullyDefined(); + string shape_name = strings::StrCat(in_node.node->name(), "/Shape"); + Node* shape_node = nullptr; + if (!is_fully_defined) { + NodeBuilder shape_builder(shape_name, "Shape"); + shape_builder.Input(in_node).Attr("out_type", DT_INT32).Attr("T", dtype); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(send_dev); + } + + // For all other devices, create a broadcast receive and connect outputs. + for (int recv_dev = 0; recv_dev < out_nodes_map.size(); ++recv_dev) { + if (out_nodes_map[recv_dev].empty()) { + continue; + } + if (is_fully_defined) { + // If the shape is fully defined, define one const node per device. + NodeBuilder shape_builder(strings::StrCat(shape_name, recv_dev), "Const"); + shape_builder.Attr("value", tensor_proto).Attr("dtype", DT_INT32); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(recv_dev); + } + Node* recv_node; + TF_RETURN_IF_ERROR( + make_builder("_NcclBroadcastRecv", strings::StrCat("Recv_", recv_dev)) + .Input(shape_node) + .Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + for (const auto& out_node : out_nodes_map[recv_dev]) { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + + return Status::OK(); +} + +// Replaces occurrences of Nccl{Reduce, Broadcast}Input/Output with their +// _Nccl...Send/Recv counterparts and removes data dependencies between them. +class NcclReplacePass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override { + if (options.graph == nullptr) { + return Status::OK(); + } + Graph* graph = options.graph->get(); + if (graph == nullptr) { + return errors::Internal( + "NCCL replacement should happen before partitioning and a " + "graph should be available."); + } + // Find reduction and broadcast ops and replace them with Send/Recv ops. + for (Node* node : graph->op_nodes()) { + StringPiece type = node->type_string(); + if (!type.starts_with("Nccl")) { + continue; + } + if (type == "NcclReduce") { + TF_RETURN_IF_ERROR(ReplaceReduce(graph, node)); + } + if (type == "NcclBroadcast") { + TF_RETURN_IF_ERROR(ReplaceBroadcast(graph, node)); + } + } + return Status::OK(); + } +}; +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, + NcclReplacePass); + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc index 532c79c24cc9596af580ee3faf463aecbc59bb07..8eb804c2e988f313ba1b340217cae20f1f5502c7 100644 --- a/tensorflow/contrib/nccl/ops/nccl_ops.cc +++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc @@ -45,7 +45,28 @@ num_devices: The number of devices participating in this reduction. shared_name: Identifier that shared between ops of the same reduction. )doc"); -REGISTER_OP("NcclReduceSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclReduceSend and _NcclReduceRecv during graph optimization stage. +REGISTER_OP("NcclReduce") + .Input("input: num_devices * T") + .Output("data: T") + .Attr("reduction: {'min', 'max', 'prod', 'sum'}") + .Attr("T: {float, float64, int32, int64}") + .Attr("num_devices: int") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Reduces `input` from `num_devices` using `reduction` to a single device. + +The graph should be constructed so that all inputs have a valid device +assignment, and the op itself is assigned one of these devices. + +input: The input to the reduction. +data: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. + )doc"); + +REGISTER_OP("_NcclReduceSend") .Input("input: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") .Attr("T: {float, float64, int32, int64}") @@ -54,19 +75,20 @@ REGISTER_OP("NcclReduceSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. +Replacement node for NcclReduce. +Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclReduceRecv") +REGISTER_OP("_NcclReduceRecv") .Input("input: T") .Output("data: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") @@ -76,21 +98,42 @@ REGISTER_OP("NcclReduceRecv") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( +Replacement node for NcclReduce. + Reduces 'input' from this op and the NcclReduceSend ops registered in the same `shared_name`. - The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. data: The reduced data received from this op and the NcclReduceSend op. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclBroadcastSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclBroadcastSend and _NcclBroadcastRecv during graph optimization stage. +REGISTER_OP("NcclBroadcast") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, float64, int32, int64}") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Sends `input` to all devices that are connected to the output. + +The graph should be constructed so that all ops connected to the output have a +valid device assignment, and the op itself is assigned one of these devices. + +input: The input to the broadcast. +output: The same as input. +shape: The shape of the input tensor. + )doc"); + +REGISTER_OP("_NcclBroadcastSend") .Input("input: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -98,19 +141,21 @@ REGISTER_OP("NcclBroadcastSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends `input` to the _NcclBroadcastRecv ops registered in the same +`shared_name`. +The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the broadcast +input: The input to the broadcast. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same broadcast. )doc"); -REGISTER_OP("NcclBroadcastRecv") - .Input("shape: int64") +REGISTER_OP("_NcclBroadcastRecv") + .Input("shape: int32") .Output("output: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -123,11 +168,12 @@ REGISTER_OP("NcclBroadcastRecv") return Status::OK(); }) .Doc(R"doc( -Sends data of shape `shape` from the NcclBroadcastSend op registered in the -same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends data of shape `shape` from the _NcclBroadcastSend op registered in the +same `shared_name`. +The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. shape: The shape of the output. diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 906d9f948acf212dce1dbbbf9ec7c60c30f389b1..8dc038b9ac992de7db8b762e3697c6693099e192 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -23,9 +23,7 @@ from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.util import loader from tensorflow.python.eager import context from tensorflow.python.framework import device -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.platform import resource_loader _nccl_ops_so = loader.load_op_library( @@ -64,13 +62,13 @@ def _all_sum_grad(op, grad): LookupError: If `reduction` is not `sum`. """ if op.get_attr('reduction') != 'sum': - raise LookupError('No gradient defined for NcclAllReduce except all_sum.') + raise LookupError('No gradient defined for NcclAllReduce except sum.') - _check_device_assignment(grad) + _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') shared_name = op.get_attr('shared_name') + '_grad' - with ops.device(grad.device): + with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( input=grad, reduction='sum', @@ -129,7 +127,7 @@ def all_max(tensors): return _apply_all_reduce('max', tensors) -def reduce_sum(tensors, dst_device): +def reduce_sum(tensors): """Returns a tensor with the reduce sum across `tensors`. The computation is done with a reduce operation, so only one tensor is @@ -138,54 +136,76 @@ def reduce_sum(tensors, dst_device): Args: tensors: The input tensors across which to sum; must be assigned to GPU devices. - dst_device: The device of the returned tensor. Returns: - A tensor containing the sum of the input tensors, with the device of the - tensor being `dst_device`. + A tensor containing the sum of the input tensors. + + Raises: + LookupError: If context is not currently using a GPU device. + """ + return _apply_reduce('sum', tensors) + + +@ops.RegisterGradient('NcclReduce') +def _reduce_sum_grad(op, grad): + """The gradients for input `Operation` of `reduce_sum`. + + Args: + op: The `sum send` `Operation` that we are differentiating. + grad: Gradient with respect to the output of the `reduce_sum` op. + + Returns: + The gradient with respect to the input of `reduce_sum` op. + + Raises: + LookupError: If the reduction attribute of op is not `sum`. """ - return _apply_reduce('sum', tensors, dst_device) + if op.get_attr('reduction') != 'sum': + raise LookupError('No gradient defined for NcclReduce except sum.') + _check_device(grad, expected=op.device) + with ops.device(op.device): + result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape) -def broadcast(src_tensor, dst_devices): - """Returns a list of tensors on `dst_devices`, each with value `tensor`. + return [result] * len(op.inputs) - The computation is done with a broadcast nccl operation, so if only some of - the returned tensors and src_tensor are evaluated then the computation will - hang. + +def broadcast(tensor): + """Returns a tensor that can be efficiently transferred to other devices. Args: - src_tensor: The tensor to send; must be assigned to a GPU device. - dst_devices: The GPU devices to receive the sent tensor. + tensor: The tensor to send; must be assigned to a GPU device. Returns: - An `Operation` to send the `src_tensor`, and a list of tensors, each with - the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`. + A tensor with the value of `src_tensor`, which can be used as input to + ops on other GPU devices. """ - if not dst_devices: - raise ValueError('Must pass >0 dst_devices to broadcast') _check_graph_mode() - _check_device_assignment(src_tensor) + _check_device(tensor) - shape = array_ops.shape(src_tensor, out_type=dtypes.int64) - num_devices = len(dst_devices) + 1 - shared_name = _get_shared_name() + with ops.device(tensor.device): + return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape) - with ops.device(src_tensor.device): - send = gen_nccl_ops.nccl_broadcast_send( - input=src_tensor, num_devices=num_devices, shared_name=shared_name) - - recvs = [] - for d in dst_devices: - with ops.device(d): - recvs.append( - gen_nccl_ops.nccl_broadcast_recv( - shape=shape, - T=src_tensor.dtype, - num_devices=num_devices, - shared_name=shared_name)) - return send, recvs +@ops.RegisterGradient('NcclBroadcast') +def _broadcast_grad(op, accumulated_grad): + """The gradients for input `Operation` of `broadcast`. + + Args: + op: The `broadcast send` `Operation` that we are differentiating. + accumulated_grad: Accumulated gradients with respect to the output of the + `broadcast` op. + + Returns: + Gradients with respect to the input of `broadcast`. + """ + # Grab inputs of accumulated_grad and replace accumulation with reduce_sum. + grads = [t for t in accumulated_grad.op.inputs] + for t in grads: + _check_device(t) + + with ops.device(op.device): + return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum') def _apply_all_reduce(reduction, tensors): @@ -198,7 +218,7 @@ def _apply_all_reduce(reduction, tensors): res = [] for t in tensors: - _check_device_assignment(t) + _check_device(t) with ops.device(t.device): res.append( gen_nccl_ops.nccl_all_reduce( @@ -210,40 +230,20 @@ def _apply_all_reduce(reduction, tensors): return res -def _apply_reduce(reduction, tensors, dst_device): +def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - if not dst_device: - raise ValueError('Must pass dst_device to reduce operations') _check_graph_mode() + for t in tensors: + _check_device(t) + result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction) try: - recv_index = next(i for i, t in enumerate(tensors) - if t.device == dst_device) + next(t for t in tensors if t.device == result.device) except StopIteration: - raise ValueError('One of the tensors must be assigned to dst_device') - shared_name = _get_shared_name() - - sends = [] - for t in tensors[:recv_index] + tensors[recv_index + 1:]: - _check_device_assignment(t) - with ops.device(t.device): - sends.append( - gen_nccl_ops.nccl_reduce_send( - input=t, - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name)) - - with ops.device(dst_device): - recv = gen_nccl_ops.nccl_reduce_recv( - input=tensors[recv_index], - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name) - - return recv, sends + raise ValueError('One input tensor must be assigned to current device') + return result _lock = threading.Lock() @@ -259,9 +259,11 @@ def _get_shared_name(): return 'c%s' % val -def _check_device_assignment(tensor): +def _check_device(tensor, expected=None): if not device.canonical_name(tensor.device): raise ValueError('Device assignment required for nccl collective ops') + if expected and expected != tensor.device: + raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) def _check_graph_mode(): diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index 96d67723a0ad197436a12924bd2b4ecb73eee4cb..255409303a578cd3a3b69c3c8d0e3464e58f08bb 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -22,8 +22,10 @@ from functools import partial import numpy as np from tensorflow.contrib import nccl +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients from tensorflow.python.platform import test @@ -36,27 +38,30 @@ def _DeviceTensors(tensors, devices): def _NcclAllReduce(nccl_fun, tensors, devices): - return nccl_fun(_DeviceTensors(tensors, devices)), [] + return nccl_fun(_DeviceTensors(tensors, devices)) def _NcclReduce(nccl_fun, tensors, devices): - d_tensors = _DeviceTensors(tensors, devices) receiver = np.random.randint(0, len(devices)) - received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver]) - return [received_tensor], send_ops + with ops.device(devices[receiver]): + return [nccl_fun(_DeviceTensors(tensors, devices))] def _NcclBroadcast(tensors, devices): sender = np.random.randint(0, len(devices)) - d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0] - other_devices = devices[:sender] + devices[sender + 1:] - send_op, received_tensors = nccl.broadcast(d_tensor, other_devices) - return received_tensors, [send_op] + with ops.device(devices[sender]): + tensor = array_ops.identity(tensors[0]) + broadcast = nccl.broadcast(tensor) + return _DeviceTensors([broadcast] * len(devices), devices) class NcclTestCase(test.TestCase): - def _Test(self, nccl_reduce, numpy_fn): + def _Test(self, + nccl_reduce, + numpy_fn, + device_sets=(['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], + ['/device:GPU:1', '/device:GPU:0'])): """Tests that nccl_reduce does the same as reduction with numpy_fn. Args: @@ -65,6 +70,7 @@ class NcclTestCase(test.TestCase): reduction. numpy_fn: A function taking two tensors and returning the reduction of the two. + device_sets: Tuple of virtual devices to run test on. """ if not test.is_gpu_available(): return # Test requires access to a GPU @@ -74,26 +80,28 @@ class NcclTestCase(test.TestCase): # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: - for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], - ['/device:GPU:1', '/device:GPU:0']]: + for devices in device_sets: shape = (3, 4) random = (np.random.random_sample(shape) - .5) * 1024 - tensors = [random.astype(dtype)] * len(devices) + tensors = [] + for _ in devices: + tensors.append(random.astype(dtype)) np_ans = tensors[0] for t in tensors[1:]: np_ans = numpy_fn(np_ans, t) - reduce_tensors, reduce_ops = nccl_reduce(tensors, devices) + reduce_tensors = nccl_reduce(tensors, devices) self.assertNotEmpty(reduce_tensors) # Test shape inference. for r in reduce_tensors: self.assertEqual(shape, r.get_shape()) + result_tensors = [array_ops.identity(t) for t in reduce_tensors] + # Test execution and results. - nccl_results = sess.run(reduce_tensors + reduce_ops) - for r in nccl_results[:len(reduce_tensors)]: - self.assertAllClose(r, np_ans) + for t in sess.run(result_tensors): + self.assertAllClose(t, np_ans) def _TestGradient(self, nccl_reduce, numpy_fn): """Tests the gradient of nccl_reduce. @@ -106,14 +114,11 @@ class NcclTestCase(test.TestCase): reduction of the two. """ def _Gradient(tensors, devices): - reduce_tensors, _ = nccl_reduce(tensors, devices) - tensor_ops = [t.op for t in reduce_tensors] - d_tensors = _DeviceTensors(tensors, devices) - grad_tensors = [ - ops.get_gradient_function(op)(op, loss) - for op, loss in zip(tensor_ops, d_tensors) - ] - return grad_tensors, [] + inputs = [array_ops.placeholder(t.dtype, t.shape) for t in tensors] + reduce_tensors = nccl_reduce(inputs, devices) + losses = _DeviceTensors(tensors, [t.device for t in reduce_tensors]) + grads = gradients.gradients(reduce_tensors, inputs, losses) + return [g for g in grads if g is not None] self._Test(_Gradient, numpy_fn) @@ -142,27 +147,43 @@ class SingleReduceTest(NcclTestCase): def testSum(self): self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y) + def testSumGrad(self): + self._TestGradient(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x) + class BroadcastTest(NcclTestCase): def testBroadcast(self): self._Test(_NcclBroadcast, lambda x, y: x) + def testBroadcastSingleDevice(self): + # Broadcasts on a single device are removed completely during rewrite. + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:GPU:0'])) + + def testBroadcastToCpuError(self): + # Broadcasts to CPU is not supported. + with self.assertRaisesRegexp( + errors.NotFoundError, + "No registered '_NcclBroadcastRecv' OpKernel for CPU devices"): + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:CPU:0'])) + + def testBroadcastGrad(self): + self._TestGradient(_NcclBroadcast, lambda x, y: x + y) + class CombinedTest(NcclTestCase): """Test all-reduce vs. single-reduce plus broadcast in one session.run.""" - def _combined(self, tensors, devices): - all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0] - single_reduce_tensors, single_reduce_ops = _NcclReduce( - nccl.reduce_sum, tensors, devices) - broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors, - devices) - all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors - return all_tensors, single_reduce_ops + broadcast_ops + def _Combined(self, tensors, devices): + all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices) + single_reduce_tensors = _NcclReduce(nccl.reduce_sum, tensors, devices) + broadcast_tensors = _NcclBroadcast(single_reduce_tensors, devices) + return all_reduce_tensors + broadcast_tensors def testCombined(self): - self._Test(self._combined, lambda x, y: x + y) + self._Test(self._Combined, lambda x, y: x + y) if __name__ == '__main__': diff --git a/tensorflow/contrib/nn/__init__.py b/tensorflow/contrib/nn/__init__.py index 89b70ddfc203c5e4102f665932ea2d726c495b7e..6cc825e47daa956693c0b813860b0c639d54e313 100644 --- a/tensorflow/contrib/nn/__init__.py +++ b/tensorflow/contrib/nn/__init__.py @@ -20,6 +20,7 @@ @@deprecated_flipped_sigmoid_cross_entropy_with_logits @@rank_sampled_softmax_loss @@sampled_sparse_softmax_loss +@@scaled_softplus """ from __future__ import absolute_import diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus.py b/tensorflow/contrib/nn/python/ops/scaled_softplus.py index 5fc11d8ec66fc6e32d5028a6b6181c784104db06..fcbfbc239ca5b8a1d4b17b403f99b7eb05db47b0 100644 --- a/tensorflow/contrib/nn/python/ops/scaled_softplus.py +++ b/tensorflow/contrib/nn/python/ops/scaled_softplus.py @@ -20,58 +20,96 @@ from __future__ import print_function from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -def scaled_softplus(x, alpha, name=None): - """Returns `alpha * ln(1 + exp(x / alpha))`, for scalar `alpha > 0`. +def _reduce_and_reshape_grad(g, t): + """Returns the gradient, sum-reduced and reshaped to `t`'s shape.""" + shape = array_ops.shape(t) + g_shape = array_ops.shape(g) + # pylint: disable=protected-access + bcast_dims, _ = gen_array_ops._broadcast_gradient_args(shape, g_shape) + # pylint: enable=protected-access + return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape) + + +def scaled_softplus(x, alpha, clip=None, name=None): + """Returns `y = alpha * ln(1 + exp(x / alpha))` or `min(y, clip)`. This can be seen as a softplus applied to the scaled input, with the output appropriately scaled. As `alpha` tends to 0, `scaled_softplus(x, alpha)` tends - to `relu(x)`. + to `relu(x)`. The clipping is optional. As alpha->0, scaled_softplus(x, alpha) + tends to relu(x), and scaled_softplus(x, alpha, clip=6) tends to relu6(x). Note: the gradient for this operation is defined to depend on the backprop inputs as well as the outputs of this operation. Args: x: A `Tensor` of inputs. - alpha: A scalar `Tensor`, indicating the amount of smoothness. The caller + alpha: A `Tensor`, indicating the amount of smoothness. The caller must ensure that `alpha > 0`. + clip: (optional) A `Tensor`, the upper bound to clip the values. name: A name for the scope of the operations (optional). Returns: - A tensor of same size and type as `x`. + A tensor of the size and type determined by broadcasting of the inputs. """ - with ops.name_scope(name, 'scaled_softplus', [x, alpha]): + clipping = clip is not None + with ops.name_scope(name, 'scaled_softplus', + [x, alpha] + ([clip] if clipping else [])): x = ops.convert_to_tensor(x, name='x') dtype = x.dtype alpha = ops.convert_to_tensor(alpha, dtype=dtype, name='alpha') - # Verify that alpha is a scalar. - alpha.get_shape().assert_has_rank(0) + # Compute the forward value. + y = alpha * nn.softplus(x / alpha) + if clipping: + clip = ops.convert_to_tensor(clip, dtype=dtype, name='clip') + y = math_ops.minimum(y, clip) def _grad(op, g): - """Backprop for scaled softplus.""" - y = op.outputs[0] - alpha = op.inputs[1] - # Prevent the expensive computations from happening before g is available. + """Backprop for scaled softplus, with optional clipping.""" + y, x, alpha = op.inputs[:3] + # Prevent the memory-expensive computations from happening before g is + # available. with ops.control_dependencies([g]): - y /= alpha + y = array_ops.identity(y) + clip_grad = [] + if clipping: + clip = op.inputs[3] + unclipped = math_ops.cast(y < clip, g.dtype) + clip_grad = [_reduce_and_reshape_grad(g * (1. - unclipped), clip)] + g *= unclipped + y /= alpha emy = math_ops.exp(-y) dy_dx = 1. - emy # The eps below avoids log(0). Note that t*log(t) -> 0 as t->0. eps = 1e-8 dy_dalpha = y * emy - dy_dx * math_ops.log(dy_dx + eps) - return g * dy_dx, math_ops.reduce_sum(g * dy_dalpha) + # Backprop to the actual inputs, but not to the output. + return [None, + _reduce_and_reshape_grad(g * dy_dx, x), + _reduce_and_reshape_grad(g * dy_dalpha, alpha)] + clip_grad - @function.Defun(dtype, dtype, - func_name='ScaledSoftplus_%s' % dtype.name, - shape_func=lambda op: [op.inputs[0].get_shape()], + if clipping: + @function.Defun(dtype, dtype, dtype, dtype, + func_name='ScaledSoftplusHelper_clip_%s' % dtype.name, + shape_func=lambda op: [op.inputs[0].shape], + python_grad_func=_grad) + def _forward_helper_clip(y, x, alpha, clip): + del x, alpha, clip # Unused. + return y + return _forward_helper_clip(y, x, alpha, clip) + # No clipping. + @function.Defun(dtype, dtype, dtype, + func_name='ScaledSoftplusHelper_%s' % dtype.name, + shape_func=lambda op: [op.inputs[0].shape], python_grad_func=_grad) - def _forward(x, alpha): - """Forward computation of scaled softplus.""" - return alpha * nn.softplus(x / alpha) - - return _forward(x, alpha) + def _forward_helper(y, x, alpha): + del x, alpha # Unused. + return y + return _forward_helper(y, x, alpha) diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py index 3a459330ceb774b033ff6622b4c90807c782f06f..b978343c6a79af856d833b0ab8002c256ce478e0 100644 --- a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py +++ b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py @@ -33,10 +33,11 @@ class ScaledSoftplusTest(test.TestCase): x = np.random.randn(3, 4).astype(np.float32) x64 = np.random.randn(3, 4).astype(np.float64) alpha = np.random.rand() + 0.01 - y = alpha * np.log(1. + np.exp(x / alpha)) + clip = np.float32(0.1) + y = np.minimum(alpha * np.log(1. + np.exp(x / alpha)), clip) y64 = alpha * np.log(1. + np.exp(x64 / alpha)) with self.test_session(use_gpu=True) as sess: - z = scaled_softplus(constant_op.constant(x), alpha) + z = scaled_softplus(constant_op.constant(x), alpha, clip) z64 = scaled_softplus(constant_op.constant(x64), alpha) z, z64 = sess.run([z, z64]) eps = 1e-6 @@ -47,18 +48,28 @@ class ScaledSoftplusTest(test.TestCase): np.random.seed(1) # Make it reproducible. x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) - alpha_np = np.float32(np.random.rand() + 0.01) + alpha_np = np.float32(np.random.rand(1, x_shape[1]) + 0.01) + clip_np = np.float32(np.random.rand(x_shape[0], 1) * 5.) with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np) alpha_tf = constant_op.constant(alpha_np) + clip_tf = constant_op.constant(clip_np) y_tf = scaled_softplus(x_tf, alpha_tf) + z_tf = scaled_softplus(x_tf, alpha_tf, clip_tf * 0.1) err = gradient_checker.compute_gradient_error([x_tf, alpha_tf], - [x_shape, []], + [x_shape, alpha_np.shape], y_tf, x_shape, [x_np, alpha_np], - delta=1e-2) - eps = 1e-4 + delta=0.002) + err_clip = gradient_checker.compute_gradient_error( + [x_tf, alpha_tf, clip_tf], + [x_shape, alpha_np.shape, clip_np.shape], + z_tf, x_shape, + [x_np, alpha_np, clip_np], + delta=0.002) + eps = 2e-4 self.assertLess(err, eps) + self.assertLess(err_clip, eps) if __name__ == '__main__': diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7ff186bc2ad7204d934c322a04ad1c3f2aa383ab --- /dev/null +++ b/tensorflow/contrib/quantize/BUILD @@ -0,0 +1,209 @@ +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "common", + srcs = ["python/common.py"], + srcs_version = "PY2AND3", + deps = [], +) + +py_library( + name = "input_to_ops", + srcs = ["python/input_to_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + ], +) + +py_test( + name = "input_to_ops_test", + size = "small", + srcs = ["python/input_to_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":input_to_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "fold_batch_norms", + srcs = ["python/fold_batch_norms.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + ":input_to_ops", + "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + ], +) + +py_test( + name = "fold_batch_norms_test", + srcs = ["python/fold_batch_norms_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":fold_batch_norms", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "copy_graph", + srcs = ["python/copy_graph.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "copy_graph_test", + size = "small", + srcs = ["python/copy_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":copy_graph", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "quant_ops", + srcs = ["python/quant_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + ], +) + +py_library( + name = "quantize", + srcs = ["python/quantize.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + ":input_to_ops", + ":quant_ops", + "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "quantize_test", + size = "small", + srcs = ["python/quantize_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":quantize", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +py_test( + name = "quantize_parameterized_test", + size = "medium", + srcs = ["python/quantize_parameterized_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":quantize", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + +py_library( + name = "quantize_graph", + srcs = [ + "__init__.py", + "python/quantize_graph.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":copy_graph", + ":fold_batch_norms", + ":quantize", + "//tensorflow/python:framework_ops", + "//tensorflow/python:variables", + ], +) + +py_test( + name = "quantize_graph_test", + size = "small", + srcs = ["python/quantize_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":quantize_graph", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/quantize/__init__.py b/tensorflow/contrib/quantize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4e4575c935e0a888c6e5e4d0db640d93e1bd49 --- /dev/null +++ b/tensorflow/contrib/quantize/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions for rewriting graphs for quantized training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import,line-too-long +from tensorflow.contrib.quantize.python.quantize_graph import * +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "create_eval_graph", + "create_training_graph", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b0674c31239ee903f5ab7ef9ae0262bb20d189 --- /dev/null +++ b/tensorflow/contrib/quantize/python/common.py @@ -0,0 +1,88 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constants used across this package.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +# Skip all operations that are backprop related or export summaries. +SKIPPED_PREFIXES = ( + 'gradients/', 'RMSProp/', 'Adagrad/', 'Const_', 'HistogramSummary', + 'ScalarSummary') + +# Valid activation ops for quantization end points. +_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity'] + +# Regular expression for recognizing nodes that are part of batch norm group. +_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm') + + +def BatchNormGroups(graph): + """Finds batch norm layers, returns their prefixes as a list of strings. + + Args: + graph: Graph to inspect. + + Returns: + List of strings, prefixes of batch norm group names found. + """ + bns = [] + for op in graph.get_operations(): + match = _BATCHNORM_RE.search(op.name) + if match: + bn = match.group(1) + if not bn.startswith(SKIPPED_PREFIXES): + bns.append(bn) + # Filter out duplicates. + return list(collections.OrderedDict.fromkeys(bns)) + + +def GetEndpointActivationOp(graph, prefix): + """Returns an Operation with the given prefix and a valid end point suffix. + + Args: + graph: Graph where to look for the operation. + prefix: String, prefix of Operation to return. + + Returns: + The Operation with the given prefix and a valid end point suffix or None if + there are no matching operations in the graph for any valid suffix + """ + for suffix in _ACTIVATION_OP_SUFFIXES: + activation = _GetOperationByNameDontThrow(graph, prefix + suffix) + if activation: + return activation + return None + + +def _GetOperationByNameDontThrow(graph, name): + """Returns an Operation with the given name. + + Args: + graph: Graph where to look for the operation. + name: String, name of Operation to return. + + Returns: + The Operation with the given name. None if the name does not correspond to + any operation in the graph + """ + try: + return graph.get_operation_by_name(name) + except KeyError: + return None diff --git a/tensorflow/contrib/quantize/python/copy_graph.py b/tensorflow/contrib/quantize/python/copy_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0376fcba82b99feabdba3b683f9db9a32db51efb --- /dev/null +++ b/tensorflow/contrib/quantize/python/copy_graph.py @@ -0,0 +1,32 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to copy a tf.Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.training import saver as saver_lib + + +def CopyGraph(graph): + """Return a copy of graph.""" + meta_graph = saver_lib.export_meta_graph( + graph=graph, collection_list=graph.get_all_collection_keys()) + graph_copy = ops.Graph() + with graph_copy.as_default(): + _ = saver_lib.import_meta_graph(meta_graph) + return graph_copy diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0889f12de6aac53f70ecfa7b70fc19ac7b95a5fe --- /dev/null +++ b/tensorflow/contrib/quantize/python/copy_graph_test.py @@ -0,0 +1,55 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.quantized.mangle.copy_graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import copy_graph +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class CopyGraphTest(test_util.TensorFlowTestCase): + + def _CompareNodeInGraph(self, node, graph): + graph_node = graph.get_operation_by_name(node.name) + self.assertEqual(str(node.node_def), str(graph_node.node_def)) + + def testCopyGraph(self): + graph = ops.Graph() + with graph.as_default(): + a = constant_op.constant(1.0) + b = variables.Variable(2.0) + c = a + b + graph_copy = copy_graph.CopyGraph(graph) + # Ensure that the three original nodes are in the new graph. + # import_meta_graph also adds a saver node to the graph which we don't care + # about in this specific use case. + for tensor in [a, b, c]: + self._CompareNodeInGraph(tensor.op, graph_copy) + # Test that the graph collections are the same. + for key in graph.get_all_collection_keys(): + self.assertEqual( + len(graph.get_collection(key)), + len(graph_copy.get_collection(key)), 'Collection %s differs.') + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py new file mode 100644 index 0000000000000000000000000000000000000000..c4166895108294148fd09ed95e6227fda17ef36f --- /dev/null +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -0,0 +1,309 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Logic to fold batch norm into preceding convolution or FC layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +from tensorflow.contrib import graph_editor +from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops + + +def FoldBatchNorms(graph): + """Finds batch norm layers in the graph, folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + # Fail immediately when the graph contains unsupported fused batch norm ops. + if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'): + raise ValueError('Fused batch norm is not supported') + + input_to_ops_map = input_to_ops.InputToOps(graph) + + for bn in common.BatchNormGroups(graph): + has_scaling = _HasScaling(graph, input_to_ops_map, bn) + + # The mangling code intimately depends on BatchNorm node's internals. + original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling) + + activation = common.GetEndpointActivationOp(graph, bn) + if activation: + nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], + [original_op.outputs[0]], + can_modify=[activation]) + if nodes_modified_count != 1: + raise ValueError('Unexpected inputs to op: %s' % activation.name) + continue + + # Treat consumer ops in bypass modules differently since they have Add + # operations instead of Relu* above. + add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1) + add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add') + nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], + [original_op.outputs[0]], + can_modify=[add_bypass]) + if nodes_modified_count != 1: + raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) + + +def _HasScaling(graph, input_to_ops_map, bn): + r"""Checks if batch norm has scaling enabled. + + Difference between batch norm with scaling and without is that with scaling: + + Rsqrt -> mul -> mul_1 + \-> mul_2 + + where + mul multiplies gamma by inverse square root of EMA of batch variance, + mul_1 multiplies output of mul with output from the base operation + (convolution, FC or depthwise convolution), + mul_2 multiplies output of mul with EMA of batch mean, + and without scaling: + + Rsqrt -> mul + \-> mul_1 + + where + mul multiplies the inverse square root of EMA of batch variance with output + from the base operation, + mul_1 multiplies inverse square root of EMA of batch variance with EMA + of batch mean. + + Args: + graph: Graph to inspect. + input_to_ops_map: InputToOps object containing mapping from tensor's name + to ops that take it as input. + bn: Batch norm layer prefix string. + + Returns: + A boolean indicating whether this batch norm layer has scaling enabled. + """ + rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') + rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) + + return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 + + +def _CreateFoldedOp(graph, context, has_scaling): + """Folds in batch norm layer into preceding convolution or FC layer. + + Creates 3 new nodes, connects their inputs and adds them to the graph: + mul is cloned into mul_fold, Conv2D or MatMul, or DepthwiseConv2d is cloned + into respective *_Fold, add is cloned into add_fold. + + Args: + graph: Graph to modify. + context: String, batch norm context, i.e. node into which BatchNorm is + nested. + has_scaling: Whether the batch norm has scaling enabled. + + Raises: + ValueError: When operation type is not supported, or input and output tensor + shapes mismatch for created operations: mul_fold, add_fold. + + Returns: + A pair of Operations, the first is the original consumer node of the batch + norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of + the folded graph (add_fold). + """ + mul_scale_name = 'mul_1' if has_scaling else 'mul' + mul_scale = graph.get_operation_by_name(context + + '/BatchNorm/batchnorm/' + + mul_scale_name) + op_below = mul_scale.inputs[0].op + weights = op_below.inputs[1] + + # Special handling for weights of depthwise convolution. + if op_below.type == 'DepthwiseConv2dNative': + new_shape = [weights.get_shape().as_list()[2], + weights.get_shape().as_list()[3]] + scale_name = 'mul' if has_scaling else 'Rsqrt' + scale = graph.get_operation_by_name(context + '/BatchNorm/batchnorm/' + + scale_name) + scale = array_ops.reshape(scale.outputs[0], new_shape, + context + '/scale_reshape') + mul_fold = _CloneOp(mul_scale, context + '/mul_fold', + [(0, weights), (1, scale)]) + elif op_below.type in ['Conv2D', 'MatMul']: + mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)]) + else: + raise ValueError('Cannot handle operation of type: %s' % op_below.op) + _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0]) + + conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold', + [(1, mul_fold.outputs[0])]) + + add_shift = graph.get_operation_by_name(context + + '/BatchNorm/batchnorm/add_1') + add_fold = _CloneOp(add_shift, context + '/add_fold', + [(0, conv_or_fc_folded.outputs[0])]) + _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0]) + return add_shift, add_fold + + +def _CloneOp(op, new_name, new_inputs): + """Clones a given op, replaces its name and some of its inputs. + + Args: + op: Operation to modify. + new_name: String, a new name to set on cloned op. + new_inputs: A list of tuples (idx, tensor), each input with corresponding + index will be replaced by the given Tensor in the cloned op. + + Returns: + Operation, the cloned op. + + Raises: + TypeError: When Operation type is not supported. + ValueError: When input shapes are incompatible. + """ + inputs = list(op.inputs) + for new_input in new_inputs: + inputs[new_input[0]] = new_input[1] + return _OP_CLONER.Clone(op, inputs, new_name) + + +class _OpCloner(object): + """Helper class that clones tf.Operations based on their type.""" + + def __init__(self): + self.op_type_to_action = { + 'Mul': self._CloneMul, + 'Add': self._CloneAdd, + 'Conv2D': self._CloneConv2d, + 'DepthwiseConv2dNative': self._CloneDepthwiseConv2d, + 'MatMul': self._CloneMatMul, + } + + def _CloneMul(self, op, inputs, new_name): + del op # Unused. + return math_ops.multiply(inputs[0], inputs[1], name=new_name).op + + def _CloneAdd(self, op, inputs, new_name): + del op # Unused. + return math_ops.add(inputs[0], inputs[1], name=new_name).op + + def _CloneConv2d(self, op, inputs, new_name): + input_tensor = inputs[0] + weights = inputs[1] + self._AssertConvShapes(op.name, input_tensor, weights) + return nn_ops.conv2d( + input_tensor, + weights, + strides=op.get_attr('strides'), + padding=op.get_attr('padding'), + use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'), + data_format=op.get_attr('data_format'), + name=new_name).op + + def _CloneDepthwiseConv2d(self, op, inputs, new_name): + input_tensor = inputs[0] + weights = inputs[1] + self._AssertConvShapes(op.name, input_tensor, weights) + return nn.depthwise_conv2d( + input_tensor, + weights, + strides=op.get_attr('strides'), + padding=op.get_attr('padding'), + name=new_name).op + + def _CloneMatMul(self, op, inputs, new_name): + weights = inputs[0] + input_tensor = inputs[1] + self._AssertFCShapes(op.name, weights, input_tensor) + return math_ops.matmul( + weights, + input_tensor, + transpose_a=op.get_attr('transpose_a'), + transpose_b=op.get_attr('transpose_b'), + name=new_name).op + + def Clone(self, op, inputs, new_name): + try: + return self.op_type_to_action[op.type](op, inputs, new_name) + except KeyError: + raise TypeError('Unsupported operation type: %s' % op.type) + + def _AssertConvShapes(self, op_name, input_tensor, weights): + """Makes sure that convolution inputs have compatible shapes. + + Args: + op_name: Operation name, only used in error message. + input_tensor: Input that is convolved. + weights: Weights of the convolution filter. + + Raises: + ValueError: When input shapes are incompatible. + """ + input_shape = input_tensor.get_shape() + weights_shape = weights.get_shape() + if (len(input_shape) != 4 or len(weights_shape) != 4 or + input_shape[3] != weights_shape[2]): + raise ValueError('Incompatible shapes for op %s inputs: %s and %s' % + (op_name, input_shape, weights_shape)) + + def _AssertFCShapes(self, op_name, weights, input_tensor): + """Makes sure that FC layer inputs have compatible shapes. + + Args: + op_name: Operation name, only used in error message. + weights: Weights used in FC layer. + input_tensor: Input into FC layer. + + Raises: + ValueError: When input shapes are incompatible. + """ + weights_shape = weights.get_shape() + input_shape = input_tensor.get_shape() + if (len(weights_shape) != 2 or len(input_shape) != 2 or + weights_shape[1] != input_shape[0]): + raise ValueError('Incompatible shapes for op %s inputs: %s and %s' % + (op_name, weights_shape, input_shape)) + +_OP_CLONER = _OpCloner() + + +def _AssertShapesMatch(op_name, in_tensor, out_tensor): + """Makes sure that shapes of input and output tensors are compatible. + + Args: + op_name: String, operation name, only used in error message. + in_tensor: Tensor, input tensor. + out_tensor: Tensor, output tensor. + + Raises: + ValueError: When input and output tensors have different shapes. + """ + in_shape = in_tensor.get_shape() + out_shape = out_tensor.get_shape() + + if not in_shape.is_compatible_with(out_shape): + raise ValueError('%s should not change tensor shape: input %s, ' + 'output %s' % (op_name, in_shape, out_shape)) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ddedb0a2c067a27d05dc1aff4c2b4c447dafe93a --- /dev/null +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -0,0 +1,521 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for folding batch norm layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import fold_batch_norms +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + +batch_norm = layers.batch_norm +conv2d = layers.conv2d +fully_connected = layers.fully_connected +separable_conv2d = layers.separable_conv2d + +_DEFAULT_BATCH_NORM_PARAMS = { + 'center': True, + 'scale': True, + 'decay': 1.0 - 0.003, + 'fused': False, +} + + +# TODO(suharshs): Use parameterized test once OSS TF supports it. +class FoldBatchNormsTest(test_util.TensorFlowTestCase): + + def _RunTestOverParameters(self, test_fn): + parameters_list = [ + # (relu, relu_op_name, with_bypass) + (nn_ops.relu6, 'Relu6', False), + (nn_ops.relu, 'Relu', False), + (nn_ops.relu6, 'Relu6', True), + (nn_ops.relu, 'Relu', True), + ] + for parameters in parameters_list: + test_fn(parameters[0], parameters[1], parameters[2]) + + def testFailsWithFusedBatchNorm(self): + self._RunTestOverParameters(self._TestFailsWithFusedBatchNorm) + + def _TestFailsWithFusedBatchNorm(self, relu, relu_op_name, with_bypass): + """Tests that batch norm fails when fused batch norm ops are present.""" + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + out_depth = 3 if with_bypass else 32 + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + batch_norm_params = _DEFAULT_BATCH_NORM_PARAMS.copy() + batch_norm_params['fused'] = True + scope = 'test/test2' if with_bypass else 'test' + node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=batch_norm_params, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + with self.assertRaises(ValueError): + fold_batch_norms.FoldBatchNorms(g) + + def _TestFoldConv2d(self, relu, relu_op_name, with_bypass): + """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + out_depth = 3 if with_bypass else 32 + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, + [scope + '/weights/read', + scope + '/BatchNorm/batchnorm/mul']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') + self.assertEqual(folded_conv.type, 'Conv2D') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, + [scope + '/convolution_Fold', + scope + '/BatchNorm/batchnorm/sub']) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldConv2d(self): + self._RunTestOverParameters(self._TestFoldConv2d) + + def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass): + """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. + + Tests that folding works even with an input shape where some dimensions are + not known (i.e. None). + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + """ + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, None, None, 3)) + out_depth = 3 if with_bypass else 32 + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', scope + '/BatchNorm/batchnorm/mul' + ]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') + self.assertEqual(folded_conv.type, 'Conv2D') + self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/convolution_Fold', scope + '/BatchNorm/batchnorm/sub' + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldConv2dUnknownShape(self): + self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) + + def _TestFoldConv2dWithoutScale(self, relu, relu_op_name, with_bypass): + """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + out_depth = 3 if with_bypass else 32 + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) + bn_params['scale'] = False + scope = 'test/test2' if with_bypass else 'test' + node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=bn_params, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, + [scope + '/weights/read', + scope + '/BatchNorm/batchnorm/Rsqrt']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') + self.assertEqual(folded_conv.type, 'Conv2D') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, + [scope + '/convolution_Fold', + scope + '/BatchNorm/batchnorm/sub']) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldConv2dWithoutScale(self): + self._RunTestOverParameters(self._TestFoldConv2dWithoutScale) + + def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass): + """Tests folding cases: inputs -> FC with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + """ + g = ops.Graph() + with g.as_default(): + batch_size, depth = 5, 256 + inputs = array_ops.zeros((batch_size, depth)) + out_depth = 256 if with_bypass else 128 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = fully_connected(inputs, out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, + [scope + '/weights/read', + scope + '/BatchNorm/batchnorm/mul']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') + self.assertEqual(folded_conv.type, 'MatMul') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, + [scope + '/MatMul_Fold', + scope + '/BatchNorm/batchnorm/sub']) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldFullyConnectedLayer(self): + self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) + + def _TestFoldFullyConnectedLayerWithoutScale(self, relu, relu_op_name, + with_bypass): + """Tests folding cases: inputs -> FC with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + """ + g = ops.Graph() + with g.as_default(): + batch_size, depth = 5, 256 + inputs = array_ops.zeros((batch_size, depth)) + out_depth = 256 if with_bypass else 128 + activation_fn = None if with_bypass else relu + bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) + bn_params['scale'] = False + scope = 'test/test2' if with_bypass else 'test' + node = fully_connected(inputs, out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=bn_params, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, + [scope + '/weights/read', + scope + '/BatchNorm/batchnorm/Rsqrt']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') + self.assertEqual(folded_conv.type, 'MatMul') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, + [scope + '/MatMul_Fold', + scope + '/BatchNorm/batchnorm/sub']) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldFullyConnectedLayerWithoutScale(self): + self._RunTestOverParameters(self._TestFoldFullyConnectedLayerWithoutScale) + + def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass): + """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d(inputs, None, [5, 5], stride=stride, + depth_multiplier=1.0, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, + [scope + '/depthwise_weights/read', + scope + '/scale_reshape']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) + + scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') + self.assertEqual(scale_reshape.type, 'Reshape') + self._AssertInputOpsAre(scale_reshape, + [scope + '/BatchNorm/batchnorm/mul', + scope + '/scale_reshape/shape']) + self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) + + folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') + self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, + [scope + '/depthwise_Fold', + scope + '/BatchNorm/batchnorm/sub']) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldDepthwiseConv2d(self): + self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) + + def _TestFoldDepthwiseConv2dWithoutScale(self, relu, relu_op_name, + with_bypass): + """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) + bn_params['scale'] = False + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d(inputs, None, [5, 5], stride=stride, + depth_multiplier=1.0, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=bn_params, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, + [scope + '/depthwise_weights/read', + scope + '/scale_reshape']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) + + scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') + self.assertEqual(scale_reshape.type, 'Reshape') + self._AssertInputOpsAre(scale_reshape, + [scope + '/BatchNorm/batchnorm/Rsqrt', + scope + '/scale_reshape/shape']) + self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) + + folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') + self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, + [scope + '/depthwise_Fold', + scope + '/BatchNorm/batchnorm/sub']) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldDepthwiseConv2dWithoutScale(self): + self._RunTestOverParameters(self._TestFoldDepthwiseConv2dWithoutScale) + + def _WeightInit(self, stddev): + """Returns a truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initializer that initializes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) + + def _AssertInputOpsAre(self, op, in_op_names): + """Asserts that all inputs to op come from in_op_names (disregarding order). + + Args: + op: Operation to check inputs for. + in_op_names: List of strings, operations where all op's inputs should + come from. + """ + expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names] + self.assertItemsEqual([t.name for t in op.inputs], expected_inputs) + + def _AssertOutputGoesToOps(self, op, graph, out_op_names): + """Asserts that outputs from op go to out_op_names (and perhaps others). + + Args: + op: Operation to check outputs for. + graph: Graph where output operations are located. + out_op_names: List of strings, operations where op's outputs should go. + """ + for out_op_name in out_op_names: + out_op = graph.get_operation_by_name(out_op_name) + self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/input_to_ops.py b/tensorflow/contrib/quantize/python/input_to_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..98755607771ff725023fdf1abbcad8e95e851e23 --- /dev/null +++ b/tensorflow/contrib/quantize/python/input_to_ops.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Logic to update a Tensorflow model graph with quantization operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from tensorflow.contrib.quantize.python import common + + +class InputToOps(object): + """Holds a mapping from tensor's name to ops that take it as input.""" + + def __init__(self, graph): + """Initializes mapping from tensor's name to ops that take it. + + Helps find edges between ops faster and avoids iterating over the whole + graph. The mapping is of type Dict[str, Set[tf.Operation]]. + + Note: while inserting operations into the graph, we do not update the + mapping, assuming that insertion points in the graph are never adjacent. + With that restriction, an out of date mapping still works fine. + + Args: + graph: Graph to process. + """ + self.mapping = collections.defaultdict(set) + for op in (op for op in graph.get_operations()): + if op.name.startswith(common.SKIPPED_PREFIXES): + continue + for op_input in op.inputs: + self.mapping[op_input].add(op) + + def ConsumerOperations(self, producer_op): + """Looks through outputs of producer_op, finds ops that take them as input. + + Args: + producer_op: Operation containing outputs to process. + + Returns: + A Set[Operation] containing all operations taking input from producer_op + outputs. + """ + result = set() + for inp in producer_op.outputs: + result.update(self.mapping[inp]) + return result diff --git a/tensorflow/contrib/quantize/python/input_to_ops_test.py b/tensorflow/contrib/quantize/python/input_to_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbd1eb711831558b94a2c5793311d5c3e85963e --- /dev/null +++ b/tensorflow/contrib/quantize/python/input_to_ops_test.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for InputToOps class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class InputToOpsTest(test_util.TensorFlowTestCase): + + def testNoConsumerOperations(self): + graph = ops.Graph() + with graph.as_default(): + input_tensor = array_ops.zeros((1, 2, 3, 4)) + + input_to_ops_map = input_to_ops.InputToOps(graph) + consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op) + + self.assertEqual(0, len(consumer_operations)) + + def testOneConsumerOperation(self): + graph = ops.Graph() + with graph.as_default(): + input_tensor = array_ops.zeros((1, 2, 3, 4)) + output_tensor = nn_ops.relu6(input_tensor) + + input_to_ops_map = input_to_ops.InputToOps(graph) + consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op) + + self.assertEqual(consumer_operations, {output_tensor.op}) + + def testSeveralConsumerOperations(self): + graph = ops.Graph() + with graph.as_default(): + input_tensor = array_ops.zeros((1, 2, 3, 4)) + output_tensor_1 = nn_ops.relu6(input_tensor) + output_tensor_2 = input_tensor + output_tensor_1 + output_tensor_3 = input_tensor * output_tensor_2 + + input_to_ops_map = input_to_ops.InputToOps(graph) + consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op) + + self.assertEqual(consumer_operations, + {output_tensor_1.op, output_tensor_2.op, + output_tensor_3.op}) + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0a38ef9fcd6f1699b0feee6d439ba69413e0899b --- /dev/null +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -0,0 +1,320 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python support for quantization operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import add_arg_scope +from tensorflow.contrib.framework.python.ops import model_variable +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import moving_averages + +EPSILON = 1e-5 + + +@add_arg_scope +def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None): + """Adds a fake quantize layer with fixed quantization interval. + + Args: + inputs: a tensor containing values to be quantized. + init_min: the lower end of quantization interval. + init_max: the upper end of quantization interval. + scope: Optional scope for name_scope. + Returns: + a tensor containing quantized values. + """ + with ops.name_scope(scope, 'FixedQuantize', values=[inputs]): + return array_ops.fake_quant_with_min_max_args( + inputs, min=init_min, max=init_max) + + +@add_arg_scope +def LastValueQuantize(inputs, + per_channel=False, + init_min=-6.0, + init_max=6.0, + updates_collection=ops.GraphKeys.UPDATE_OPS, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + scope=None, + reuse=None, + is_training=True, + num_bits=8, + narrow_range=False): + """Adds a layer that collects quantization ranges as last input ranges. + + LastValueQuantize creates variables called 'min' and 'max', representing the + interval used for quantization and clamping. + + Args: + inputs: a tensor containing values to be quantized. + per_channel: (Optional) a boolean specifying whether to use different + quantization ranges per output channel. + init_min: a float scalar, the initial value for variable min. + init_max: a float scalar, the initial value for variable max. + updates_collection: (Optional) collections to collect the update ops for + computation. + vars_collection: (Optional) collection where to store variables for + quantization interval ends. + scope: Optional scope for variable_scope. + reuse: whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + is_training: Whether the op is applied to a training or eval graph. + num_bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1]. + Returns: + a tensor containing quantized values. + """ + with variable_scope.variable_scope( + scope, 'LastValueQuantize', values=[inputs], reuse=reuse): + input_shape = inputs.get_shape() + input_dim = len(input_shape) + if per_channel: + # Only support quantizing 1-, 2- and 4-dimensional tensors. + assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' + ' scope: %s' % (input_shape, scope)) + min_max_shape = [input_shape[-1]] + else: + min_max_shape = [] + + min_var = model_variable( + 'min', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_min), + collections=[vars_collection], + trainable=False) + max_var = model_variable( + 'max', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_max), + collections=[vars_collection], + trainable=False) + if not is_training: + return _FakeQuantWithMinMaxVars( + inputs, + min_var, + max_var, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + + if per_channel: + if input_dim == 2: + reduce_dims = [0] + elif input_dim == 4: + reduce_dims = [0, 1, 2] + + if per_channel: + if input_dim >= 2: + batch_min = math_ops.reduce_min( + inputs, reduction_indices=reduce_dims, name='BatchMin') + else: + batch_min = inputs + else: + batch_min = math_ops.reduce_min(inputs, name='BatchMin') + batch_min -= EPSILON + # B-eng requires that 0.0 if always in the [min; max] range. + batch_min = math_ops.minimum(batch_min, 0.0) + assign_min_op = state_ops.assign( + min_var, batch_min, name='AssignMinLast').op + ops.add_to_collection(updates_collection, assign_min_op) + + if per_channel: + if input_dim >= 2: + batch_max = math_ops.reduce_max( + inputs, reduction_indices=reduce_dims, name='BatchMax') + else: + batch_max = inputs + else: + batch_max = math_ops.reduce_max(inputs, name='BatchMax') + batch_max += EPSILON + # B-eng requires that 0.0 if always in the [min; max] range. + batch_max = math_ops.maximum(batch_max, 0.0) + assign_max_op = state_ops.assign( + max_var, batch_max, name='AssignMaxLast').op + ops.add_to_collection(updates_collection, assign_max_op) + + return _FakeQuantWithMinMaxVars( + inputs, + batch_min, + batch_max, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + + +@add_arg_scope +def MovingAvgQuantize(inputs, + per_channel=False, + init_min=-6.0, + init_max=6.0, + ema_decay=0.999, + updates_collection=ops.GraphKeys.UPDATE_OPS, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + scope=None, + reuse=None, + is_training=True, + num_bits=8, + narrow_range=False): + """Adds a layer that collects quantization ranges as EMAs of input ranges. + + MovingAvgQuantize creates variables called 'min' and 'max', representing the + interval used for quantization and clamping. + + Args: + inputs: a tensor containing values to be quantized. + per_channel: (default False) a boolean specifying whether to use different + quantization ranges per output channel. + init_min: a float scalar, the initial value for variable min. + init_max: a float scalar, the initial value for variable max. + ema_decay: EMA decay parameter. + updates_collection: (Optional) collections to collect the update ops for + computation. + vars_collection: (Optional) collection where to store variables for + quantization interval ends. + scope: Optional scope for variable_scope. + reuse: whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + is_training: Whether the op is applied to a training or eval graph. + num_bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1]. + Returns: + a tensor containing quantized values. + """ + with variable_scope.variable_scope( + scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse): + input_shape = inputs.get_shape() + input_dim = len(input_shape) + if per_channel: + # Only support quantizing 1-, 2- and 4-dimensional tensors. + assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' + ' scope: %s' % (input_shape, scope)) + min_max_shape = [input_shape[-1]] + else: + min_max_shape = [] + + min_var = model_variable( + 'min', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_min), + collections=[vars_collection], + trainable=False) + max_var = model_variable( + 'max', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_max), + collections=[vars_collection], + trainable=False) + if not is_training: + return _FakeQuantWithMinMaxVars( + inputs, + min_var, + max_var, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + if per_channel: + if input_dim == 2: + reduce_dims = [0] + elif input_dim == 4: + reduce_dims = [0, 1, 2] + + if per_channel: + if input_dim >= 2: + batch_min = math_ops.reduce_min( + inputs, reduction_indices=reduce_dims, name='BatchMin') + else: + batch_min = inputs + else: + batch_min = math_ops.reduce_min(inputs, name='BatchMin') + # B-eng requires that 0.0 if always in the [min; max] range. + batch_min = math_ops.minimum(batch_min, 0.0) + assign_min_op = moving_averages.assign_moving_average( + min_var, batch_min, ema_decay, name='AssignMinEma').op + ops.add_to_collection(updates_collection, assign_min_op) + + if per_channel: + if input_dim >= 2: + batch_max = math_ops.reduce_max( + inputs, reduction_indices=reduce_dims, name='BatchMax') + else: + batch_max = inputs + else: + batch_max = math_ops.reduce_max(inputs, name='BatchMax') + # B-eng requires that 0.0 if always in the [min; max] range. + batch_max = math_ops.maximum(batch_max, 0.0) + assign_max_op = moving_averages.assign_moving_average( + max_var, batch_max, ema_decay, name='AssignMaxEma').op + ops.add_to_collection(updates_collection, assign_max_op) + + return _FakeQuantWithMinMaxVars( + inputs, + min_var, + max_var, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + + +def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits, + narrow_range): + """Adds a fake quantization operation. + + Depending on value of per_channel, this operation may do global quantization + or per channel quantization. min_var and max_var should have corresponding + shapes: [1] when per_channel == False and [d] when per_channel == True. + + Args: + inputs: a tensor containing values to be quantized. + min_var: a variable containing quantization range lower end(s). + max_var: a variable containing quantization range lupper end(s). + per_channel: a boolean specifying whether to use per-channel quantizatioh. + num_bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1]. + Returns: + a tensor containing quantized values. + """ + + if per_channel: + assert len(min_var.get_shape()) == 1 + assert len(max_var.get_shape()) == 1 + with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): + return array_ops.fake_quant_with_min_max_vars_per_channel( + inputs, + min_var, + max_var, + num_bits=num_bits, + narrow_range=narrow_range) + else: + assert min_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison + assert max_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison + with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): + return array_ops.fake_quant_with_min_max_vars( + inputs, + min_var, + max_var, + num_bits=num_bits, + narrow_range=narrow_range) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..3645d034cdb2b82af25c6c8674bf781976ffbf0f --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -0,0 +1,364 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Logic to update a Tensorflow model graph with quantization operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +from tensorflow.contrib import graph_editor +from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.contrib.quantize.python import quant_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import training_util + +# Operation types used to select oerations of interest. +_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'} + +# Custom key for storing and retrieving update ops used by quantizing nodes. +_UPDATE_QUANT_OPS = 'update_quant_ops' + + +def Quantize(graph, + weight_bits=8, + weight_narrow_range=False, + activation_bits=8, + ema_decay=0.999, + quant_delay=None, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + is_training=True, + quantize_folded_weights_use_ema=False): + """Updates graph with quantization operations. + + Args: + graph: Graph to modify. + weight_bits: Number of bits to use for quantizing weights. + weight_narrow_range: Whether to use a more efficient narrow range for + weights quantization. With weight_narrow_range true, the range is + [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. + activation_bits: Number of bits to use for quantizing activations. + ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update + quantization intervals for quantizing activations (see here about EMA: + https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). + quant_delay: (Optional, default None) Int, count of global steps for which + to delay quantization. This helps weights stabilize at the start of + training. + vars_collection: (Optional) Collection where to store the variables for + quantization interval ends. + is_training: (Optional) Whether quantizing training graph or eval graph. + quantize_folded_weights_use_ema: (Optional, default False) Whether to + quantize weights after batchnorm-folding with exponential average + quantization. + Raises: + ValueError: When quantization fails. + """ + context = _QuantizeContext(graph, weight_bits, weight_narrow_range, + activation_bits, ema_decay, quant_delay, + vars_collection, is_training, + quantize_folded_weights_use_ema) + + graph_ops = graph.get_operations() + + # Filter out backprop and summary related operations, leave only interesting + # op types. + def _IsInterestingOpWithWeights(op): + return (op.type in _QUANTIZABLE_TYPES and + not op.name.startswith(common.SKIPPED_PREFIXES)) + + for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)): + if op.name.endswith('/depthwise'): + # Separable convolution may consist of 2 convolution nodes. If so, + # skip .../depthwise and only quantize the top one. + separable_conv = context.GetOperationByNameDontThrow( + op.name[:-len('/depthwise')]) + if separable_conv and separable_conv.type == 'Conv2D': + continue + if not op.name.endswith('_Fold'): + folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold') + # Do nothing if found, it will be quantized when it is iterated over. + if not folded_op: + context.QuantizeOpWithWeights(op, folded=False) + else: + context.QuantizeOpWithWeights(op, folded=True) + + # Once all quantization ops have been inserted in the graph, collect update + # ops for their variables and modify the TF Slim update barrier (see + # https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py) + # to depend on them. + try: + update_barrier = graph.get_operation_by_name('update_barrier') + except KeyError: + # In evaluation graph, this barrier may not exist. + return None + update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS) + graph_editor.add_control_inputs(update_barrier, update_quant_ops) + + +class _QuantizeContext(object): + """Context holds references needed for quantization.""" + + def __init__(self, + graph, + weight_bits, + weight_narrow_range, + activation_bits, + ema_decay=0.999, + quant_delay=None, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + is_training=True, + quantize_folded_weights_use_ema=False): + """Initializes context to hold references needed for quantization. + + Args: + graph: Graph to modify. + weight_bits: Number of bits to use for quantizing weights. + weight_narrow_range: Whether to use a more efficient narrow range for + weights quantization. With weight_narrow_range true, the range is + [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. + activation_bits: Number of bits to use for quantizing activations. + ema_decay: (Optional) Float, EMA decay parameter. + quant_delay: (Optional, default None) Int, count of global steps for which + to delay quantization. This helps weights stabilize at the start of + training. + vars_collection: (Optional) Collection where to store the variables for + quantization interval ends. + is_training: (Optional) Whether quantizing training or eval graph. + quantize_folded_weights_use_ema: (Optional, default False) Whether to + quantize weights after batchnorm-folding with exponential average + quantization. + """ + self.graph = graph + self.weight_bits = weight_bits + self.weight_narrow_range = weight_narrow_range + self.activation_bits = activation_bits + self.ema_decay = ema_decay + self.quant_delay = quant_delay + self.vars_collection = vars_collection + self.is_training = is_training + self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema + self.input_to_ops_map = input_to_ops.InputToOps(graph) + + def QuantizeOpWithWeights(self, op, folded): + """Quantizes around the specific operation with or without batch norm. + + Args: + op: Operation to quantize. + folded: Operation has been folded and needs special handling if True. + Raises: + ValueError: When quantization fails. + """ + # Op name component before the last slash will be used as context. + context = re.search(r'^(.*)/([^/]+)', op.name).group(1) + + # Quantize weights. + if folded: + producer_op = self.graph.get_operation_by_name(context + '/mul_fold') + else: + try: + input_idx = next(i for i, v in enumerate(op.inputs) + if '/weights/' in v.name or + '/depthwise_weights' in v.name) + except StopIteration: + raise ValueError('No inputs to quantize for op: %s' % op) + producer_op = op.inputs[input_idx].op + + # If batch norm is used, the folded weights depend on the batch std, hence + # it is sensible to use EMA during training to smooth out the noise. This is + # controlled by the flag quantize_folded_weights_use_ema. Its default is + # False for backward compatibility. + # If there is no batch norm, weights do not depend on the batch and using + # the latest value of min and max is more efficient. + weight_use_ema = folded and self.quantize_folded_weights_use_ema + self._InsertQuantOp( + context, + producer_op, [op], + name='weights_quant', + moving_avg=weight_use_ema, + delay_requested=weight_use_ema, + bits=self.weight_bits, + narrow_range=self.weight_narrow_range) + + # Important: do not quantize biases here. During inference they are + # quantized to 32 bits, which is much finer than 8 bit quantization and + # depends on weight and input activation ranges. + + # Find activation and (optionally) Add operations to quantize. + activation_op, add_op, add_context = self._GetReluAndAddOperations(context, + op) + if add_op: + original_context = context + context = add_context + + # Quantize activation outputs. + consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op) + self._InsertQuantOp( + context, + activation_op, + consumer_ops, + name='act_quant', + moving_avg=True, + init_min=0.0, + bits=self.activation_bits, + narrow_range=False) + + # When a bypass connection was found, also quantize Add op input. + if add_op: + + def _QuantizeAddInput(add_input): + if folded: + return add_input.op.name.endswith('/add_fold') + else: + return add_input.op.name.startswith(original_context + '/') + + for add_input in add_op.inputs: + if _QuantizeAddInput(add_input): + self._InsertQuantOp( + original_context, + add_input.op, [add_op], + name='conv_quant', + moving_avg=True, + bits=self.activation_bits, + narrow_range=False) + + def _GetReluAndAddOperations(self, context, op): + """Looks up a Relu* and Add operations in given context. + + Args: + context: Context where to look for operations. + op: Operation to quantize. + + Returns: + A triplet (Operation, Operation, string), the first element is an end + point operation, the second is Add operation (optional), the third element + is string context where the Add operation was found (optional). + + Raises: + ValueError: When operations cannot be found. + """ + activation_op = common.GetEndpointActivationOp(self.graph, context) + if activation_op: + return activation_op, None, None + + if '/' in context: + # If no activation op is there, look for them one level up. + add_context = re.search(r'^(.*)/([^/]+)', context).group(1) + activation_op = common.GetEndpointActivationOp(self.graph, add_context) + if not activation_op: + # Still no Relu, can happen on the top layer, just find the next node up, + # make sure it is BiasAdd. + consumers = [c for outp in op.outputs for c in outp.consumers()] + if len(consumers) != 1 or consumers[0].type != 'BiasAdd': + raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) + return consumers[0], None, None + if add_context: + add_op = self.GetOperationByNameDontThrow(add_context + '/Add') + return activation_op, add_op, add_context + else: + raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) + + def GetOperationByNameDontThrow(self, name): + """Returns an Operation with the given name. + + Args: + name: Name of Operation to return. + + Returns: + The Operation with the given name. None if the name does not correspond to + any operation in the graph + """ + try: + return self.graph.get_operation_by_name(name) + except KeyError: + return None + + def _InsertQuantOp( + self, + context, + producer, + consumers, + name, + moving_avg=True, + init_min=-6.0, + init_max=6.0, + delay_requested=True, + bits=8, + narrow_range=False,): + """Inserts a quant op between a producer op and (multiple) consumer ops. + + Args: + context: Context where producer and consumer operations are nested. + producer: Producer operation of the pairs where quantization will be + inserted. + consumers: Consumer operations of the pairs. + name: Name for the new quantization op within the context. + moving_avg: Specifies whether to use exponential moving average or just + the last value seen. + init_min: Starting minimum value for the new quantization op. + init_max: Starting maximum value for the new quantization op. + delay_requested: If true, implement quantization delay where needed. + False value explicitly disables delay quantization everywhere. + bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^bits - 1] or wide range [0; 2^bits - 1]. + Raises: + ValueError: When producer operation is not directly connected to the + consumer operation. + """ + scope = context + '/' + name + inputs = producer.outputs[0] + if moving_avg: + quant = (quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=self.ema_decay, + is_training=self.is_training, + num_bits=bits, + narrow_range=narrow_range, + updates_collection=_UPDATE_QUANT_OPS, + vars_collection=self.vars_collection, + scope=scope)) + else: + quant = (quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=self.is_training, + num_bits=bits, + narrow_range=narrow_range, + updates_collection=_UPDATE_QUANT_OPS, + vars_collection=self.vars_collection, + scope=scope)) + + if delay_requested and self.quant_delay and self.quant_delay > 0: + activate_quant = math_ops.greater_equal( + training_util.get_global_step(), + self.quant_delay, + name=scope + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=scope + '/delayed_quant') + + nodes_modified_count = graph_editor.reroute_ts( + [quant], [inputs], can_modify=consumers) + if nodes_modified_count != len(consumers): + raise ValueError('Some inputs not quantized for ops: [%s]' % + ', '.join([consumer.name for consumer in consumers])) diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf3e92b8ea518fbbe55628b856e0191c949c619 --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -0,0 +1,114 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""API to simulate quantization on a python graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import copy_graph +from tensorflow.contrib.quantize.python import fold_batch_norms +from tensorflow.contrib.quantize.python import quantize +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables + + +def _create_graph(input_graph, is_training, elements=None): + """Returns a transformed training input_graph for simulated quantization. + + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + is_training: Whether quantizing training or eval graph. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + + Returns: + Returns a tuple(g, l) where: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + # TODO(suharshs): Describe the process in more detail in the doc string. + g = copy_graph.CopyGraph(input_graph) + fold_batch_norms.FoldBatchNorms(g) + quantize.Quantize(g, is_training=is_training) + return_elements = [] + if elements is None: + elements = [] + for element in elements: + if isinstance(element, (ops.Tensor, variables.Variable)): + return_elements.append(g.get_tensor_by_name(element.name)) + elif isinstance(element, ops.Operation): + return_elements.append(g.get_operation_by_name(element.name)) + else: + raise ValueError( + 'elements must consist of Tensor or Operation objects, got: ', + str(element)) + return g, return_elements + + +def create_training_graph(input_graph, elements=None): + """Returns a transformed training input_graph for simulated quantization. + + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + + Returns: + Returns a tuple(g, l) where: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + return _create_graph(input_graph, True, elements) + + +def create_eval_graph(input_graph, elements=None): + """Returns a transformed eval input_graph for simulated quantization. + + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + + Returns: + Returns a tuple(g, l) where: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + return _create_graph(input_graph, False, elements) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..382076672a70c873ae7c1384e0706231a0ba8a55 --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -0,0 +1,75 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for the quantize_graph graph rewriting API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import quantize_graph +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class QuantizeTest(test_util.TensorFlowTestCase): + + # We have a lot of other tests that test the details of the rewrite, here we + # just the specific features of the quantize_graph API. + def testReturnedElementsTraining(self): + graph = ops.Graph() + with graph.as_default(): + a = constant_op.constant(1.0) + b = variables.Variable(2.0) + c = a + b + elements = [a, b, c.op] + for element in elements: + print(element) + q_graph, returned_elements = quantize_graph.create_training_graph( + graph, elements=elements) + # Make sure q_graph is different from graph. + self.assertTrue(graph != q_graph) + # Check that the returned elements are part of the new graph. + for returned_element in returned_elements: + self.assertEqual(q_graph, returned_element.graph) + # Check that the elements match with the one from the input graph. + for element, returned_element in zip(elements, returned_elements): + self.assertEqual(element.name, returned_element.name) + + # We have a lot of other tests that test the details of the rewrite, here we + # just the specific features of the quantize_graph API. + def testReturnedElementsEval(self): + graph = ops.Graph() + with graph.as_default(): + a = constant_op.constant(1.0) + b = variables.Variable(2.0) + c = a + b + elements = [a, b, c.op] + q_graph, returned_elements = quantize_graph.create_eval_graph( + graph, elements=elements) + # Make sure q_graph is different from graph. + self.assertTrue(graph != q_graph) + # Check that the returned elements are part of the new graph. + for returned_element in returned_elements: + self.assertEqual(q_graph, returned_element.graph) + # Check that the elements match with the one from the input graph. + for element, returned_element in zip(elements, returned_elements): + self.assertEqual(element.name, returned_element.name) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a32a7266a4c3ddf9a481fd9b292ab0f1812a9a --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -0,0 +1,701 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Parameterized unit tests for quantizing a Tensorflow graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import quantize +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest +from tensorflow.python.training import training + +batch_norm = layers.batch_norm +conv2d = layers.conv2d +fully_connected = layers.fully_connected +separable_conv2d = layers.separable_conv2d + +_DEFAULT_BATCH_NORM_PARAMS = { + 'center': True, + 'scale': True, + 'decay': 1.0 - 0.003, + 'fused': False, +} + + +# TODO(suharshs): Use parameterized test once OSS TF supports it. +class QuantizeTest(test_util.TensorFlowTestCase): + + def _RunTestOverParameters(self, test_fn): + parameters_list = [ + # (activation, activation_op_name, with_bypass, delay) + (nn_ops.relu6, 'Relu6', False, None), + (nn_ops.relu, 'Relu', False, None), + (array_ops.identity, 'Identity', False, None), + (nn_ops.relu6, 'Relu6', False, 5000), + (nn_ops.relu, 'Relu', False, 5000), + (array_ops.identity, 'Identity', False, 5000), + (nn_ops.relu6, 'Relu6', True, None), + (nn_ops.relu, 'Relu', True, None), + (array_ops.identity, 'Identity', True, None), + (nn_ops.relu6, 'Relu6', True, 5000), + (nn_ops.relu, 'Relu', True, 5000), + (array_ops.identity, 'Identity', True, 5000) + ] + for parameters in parameters_list: + test_fn(parameters[0], parameters[1], parameters[2], parameters[3]) + + def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, + with_bypass, delay): + """Tests quantization: inputs -> Conv2d no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + out_depth = 3 if with_bypass else 32 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph, quant_delay=delay) + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', + scope + '/weights/read' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + '/convolution' + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/BiasAdd' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def testQuantize_Conv2dWithoutBatchNorm(self): + self._RunTestOverParameters(self._TestQuantize_Conv2dWithoutBatchNorm) + + def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, + with_bypass, delay): + """Tests quantization: inputs -> FC no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, depth = 5, 256 + inputs = array_ops.zeros((batch_size, depth)) + out_depth = 256 if with_bypass else 128 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = fully_connected(inputs, out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph, quant_delay=delay) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', + scope + '/weights/read' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + '/MatMul' + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/BiasAdd' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def testQuantize_FCWithoutBatchNorm(self): + self._RunTestOverParameters(self._TestQuantize_FCWithoutBatchNorm) + + def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( + self, activation, activation_op_name, with_bypass, delay): + """Tests quantization: inputs -> DWConv2d no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d(inputs, None, [5, 5], stride=stride, + depth_multiplier=1.0, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph, quant_delay=delay) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', + scope + '/depthwise_weights/read' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + '/depthwise' + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/BiasAdd' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): + self._RunTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + + def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay): + """Tests quantization: inputs -> Conv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + self._testQuantize_Conv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + use_ema=True) + self._testQuantize_Conv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + use_ema=False) + + def testQuantize_Conv2dWithBatchNorm(self): + self._RunTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) + + def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, use_ema): + """Tests quantization: inputs -> Conv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + use_ema: Bool, when true uses EMA quantization for BN folded weights. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + out_depth = 3 if with_bypass else 32 + scope = 'test/test2' if with_bypass else 'test' + node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + scope=scope) + # Manually fold the batch norm. + weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] + bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') + .outputs[0]) + mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') + stride = [stride, stride] + conv_fold = nn_ops.convolution( + input=inputs, + filter=mul_fold, + padding='SAME', + strides=stride, + data_format='NHWC', + name=scope + '/convolution_Fold') + bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') + .outputs[0]) + add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + # Manually add a bypass (optionaly) and an activation. + if with_bypass: + node = math_ops.add(inputs, add_fold, name='test/Add') + else: + node = add_fold + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize( + graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), + scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/mul_fold' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if (delay and use_ema) else '/convolution_Fold') + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/add_fold' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay): + """Tests quantization: inputs -> FC with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + self._testQuantize_FCWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + use_ema=True) + self._testQuantize_FCWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + use_ema=False) + + def testQuantize_FCWithBatchNorm(self): + self._RunTestOverParameters(self._TestQuantize_FCWithBatchNorm) + + def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, use_ema): + """Tests quantization: inputs -> FC with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + use_ema: Bool, when true uses EMA quantization for BN folded weights. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, depth = 5, 256 + inputs = array_ops.zeros((batch_size, depth)) + out_depth = 256 if with_bypass else 128 + scope = 'test/test2' if with_bypass else 'test' + node = fully_connected(inputs, out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + scope=scope) + # Manually fold the batch norm. + weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] + bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') + .outputs[0]) + mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') + fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold') + bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') + .outputs[0]) + add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold') + # Manually add a bypass (optionaly) and an activation. + if with_bypass: + node = math_ops.add(inputs, add_fold, name='test/Add') + else: + node = add_fold + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize( + graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), + scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/mul_fold' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if delay and use_ema else '/MatMul_Fold') + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/add_fold' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def _TestQuantize_DepthwiseConv2dWithBatchNorm( + self, activation, activation_op_name, with_bypass, delay): + """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + self._testQuantize_DepthwiseConv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + use_ema=True) + self._testQuantize_DepthwiseConv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + use_ema=False) + + def testQuantize_DepthwiseConv2dWithBatchNorm(self): + self._RunTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + + def _testQuantize_DepthwiseConv2dWithBatchNorm( + self, activation, activation_op_name, with_bypass, delay, use_ema): + """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + use_ema: Bool, when true uses EMA quantization for BN folded weights. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d(inputs, None, [5, 5], stride=stride, + depth_multiplier=1.0, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + scope=scope) + # Manually fold the batch norm. + weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read') + .outputs[0]) + bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') + .outputs[0]) + new_shape = [ + weights.get_shape().as_list()[2], weights.get_shape().as_list()[3] + ] + bn_mult_reshaped = array_ops.reshape( + bn_mult, new_shape, name=scope + '/gamma_reshape') + mul_fold = math_ops.multiply( + weights, bn_mult_reshaped, name=scope + '/mul_fold') + stride = [1, stride, stride, 1] + conv_fold = nn_ops.depthwise_conv2d( + input=inputs, + filter=mul_fold, + padding='SAME', + strides=stride, + name=scope + '/depthwise_Fold') + bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') + .outputs[0]) + add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + # Manually add a bypass (optionaly) and an activation. + if with_bypass: + node = math_ops.add(inputs, add_fold, name='test/Add') + else: + node = add_fold + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize( + graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), + scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/mul_fold' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if delay and use_ema else '/depthwise_Fold') + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/add_fold' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def _WeightInit(self, stddev): + """Returns truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initialized that initialzes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) + + def _AssertInputOpsAre(self, op, in_op_names): + """Asserts that all inputs to op come from in_op_names (disregarding order). + + Args: + op: Operation to check inputs for. + in_op_names: List of strings, operations where all op's inputs should + come from. + """ + expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names] + self.assertItemsEqual([t.name for t in op.inputs], expected_inputs) + + def _AssertOutputGoesToOps(self, op, graph, out_op_names): + """Asserts that outputs from op go to out_op_names (and perhaps others). + + Args: + op: Operation to check outputs for. + graph: Graph where output operations are located. + out_op_names: List of strings, operations where op's outputs should go. + """ + for out_op_name in out_op_names: + out_op = graph.get_operation_by_name(out_op_name) + self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bd809bb7de0b674671d09e4a941675976ce8ab --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -0,0 +1,92 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for quantizing a Tensorflow graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import quantize +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + +conv2d = layers.conv2d + + +class QuantizeTest(test_util.TensorFlowTestCase): + + def testInsertQuantOpFailsWhenOpsNotConnected(self): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d(inputs, 32, [5, 5], stride=2, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, scope='test') + relu = nn_ops.relu6(inputs) + + context = quantize._QuantizeContext(graph=graph, weight_bits=8, + weight_narrow_range=True, + activation_bits=8) + # Inserting a quantization op between two unconnected ops should fail with + # ValueError. + with self.assertRaises(ValueError) as err: + context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') + self.assertEqual( + str(err.exception), 'Some inputs not quantized for ops: [Relu6]') + + def _WeightInit(self, stddev): + """Returns truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initialized that initialzes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) + + def _AssertInputOpsAre(self, op, in_op_names): + """Asserts that all inputs to op come from in_op_names (disregarding order). + + Args: + op: Operation to check inputs for. + in_op_names: List of strings, operations where all op's inputs should + come from. + """ + expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names] + self.assertItemsEqual([t.name for t in op.inputs], expected_inputs) + + def _AssertOutputGoesToOps(self, op, graph, out_op_names): + """Asserts that outputs from op go to out_op_names (and perhaps others). + + Args: + op: Operation to check outputs for. + graph: Graph where output operations are located. + out_op_names: List of strings, operations where op's outputs should go. + """ + for out_op_name in out_op_names: + out_op = graph.get_operation_by_name(out_op_name) + self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index db190a1a41668bff3f6db1c674192980db068838..8b34465d21d14508c24056b588f2533d8fea6a1d 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -27,13 +27,15 @@ import math from tensorflow.contrib.receptive_field.python.util import graph_compute_order from tensorflow.contrib.util import make_ndarray from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.framework import ops as framework_ops +import numpy as np # White-listed layer operations, which do not affect the receptive field # computation. _UNCHANGED_RF_LAYER_OPS = [ - "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity", - "VariableV2", "Sub", "Rsqrt", "ConcatV2" -] + 'Add', 'BiasAdd', 'Ceil', 'ConcatV2', 'Const', 'Floor', 'Identity', 'Log', + 'Mul', 'Pow', 'RealDiv', 'Relu', 'Round', 'Rsqrt', 'Softplus', 'Sub', + 'VariableV2'] # Different ways in which padding modes may be spelled. _VALID_PADDING = ["VALID", b"VALID"] @@ -238,7 +240,8 @@ def _get_layer_params(node, name_to_order_node): padding_x = 0 padding_y = 0 else: - raise ValueError("Unknown layer op: %s" % node.op) + raise ValueError("Unknown layer for operation '%s': %s" % + (node.name, node.op)) return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y @@ -304,13 +307,103 @@ def _get_effective_padding_node_input(stride, padding, return stride * effective_padding_output + padding -def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): - """Computes receptive field (RF) parameters from a GraphDef object. +class ReceptiveField: + """ + Receptive field of a convolutional neural network. + + Args: + size: Receptive field size. + stride: Effective stride. + padding: Effective padding. + """ + def __init__(self, size, stride, padding): + self.size = np.asarray(size) + self.stride = np.asarray(stride) + self.padding = np.asarray(padding) + + def compute_input_center_coordinates(self, y, axis=None): + """ + Computes the center of the receptive field that generated a feature. + + Args: + y: An array of feature coordinates with shape `(..., d)`, where `d` is the + number of dimensions of the coordinates. + axis: The dimensions for which to compute the input center coordinates. + If `None` (the default), compute the input center coordinates for all + dimensions. + + Returns: + x: Center of the receptive field that generated the features, at the input + of the network. + + Raises: + ValueError: If the number of dimensions of the feature coordinates does + not match the number of elements in `axis`. + """ + # Use all dimensions. + if axis is None: + axis = range(self.size.size) + # Ensure axis is a list because tuples have different indexing behavior. + axis = list(axis) + y = np.asarray(y) + if y.shape[-1] != len(axis): + raise ValueError("Dimensionality of the feature coordinates `y` (%d) " + "does not match dimensionality of `axis` (%d)" % + (y.shape[-1], len(axis))) + return - self.padding[axis] + y * self.stride[axis] + \ + (self.size[axis] - 1) / 2 + + def compute_feature_coordinates(self, x, axis=None): + """ + Computes the position of a feature given the center of a receptive field. + + Args: + x: An array of input center coordinates with shape `(..., d)`, where `d` + is the number of dimensions of the coordinates. + axis: The dimensions for which to compute the feature coordinates. + If `None` (the default), compute the feature coordinates for all + dimensions. + + Returns: + y: Coordinates of the features. + + Raises: + ValueError: If the number of dimensions of the input center coordinates + does not match the number of elements in `axis`. + """ + # Use all dimensions. + if axis is None: + axis = range(self.size.size) + # Ensure axis is a list because tuples have different indexing behavior. + axis = list(axis) + x = np.asarray(x) + if x.shape[-1] != len(axis): + raise ValueError("Dimensionality of the input center coordinates `x` " + "(%d) does not match dimensionality of `axis` (%d)" % + (x.shape[-1], len(axis))) + return (x + self.padding[axis] + (1 - self.size[axis]) / 2) / \ + self.stride[axis] + + def __iter__(self): + return iter(np.concatenate([self.size, self.stride, self.padding])) + + +def compute_receptive_field_from_graph_def(graph_def, input_node, output_node, + stop_propagation=None): + """Computes receptive field (RF) parameters from a Graph or GraphDef object. + + The algorithm stops the calculation of the receptive field whenever it + encounters an operation in the list `stop_propagation`. Stopping the + calculation early can be useful to calculate the receptive field of a + subgraph such as a single branch of the + [inception network](https://arxiv.org/abs/1512.00567). Args: - graph_def: GraphDef object. - input_node: Name of the input node from graph. - output_node: Name of the output node from graph. + graph_def: Graph or GraphDef object. + input_node: Name of the input node or Tensor object from graph. + output_node: Name of the output node or Tensor object from graph. + stop_propagation: List of operation or scope names for which to stop the + propagation of the receptive field. Returns: rf_size_x: Receptive field size of network in the horizontal direction, with @@ -331,6 +424,18 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): cannot be found. For network criterion alignment, see photos/vision/features/delf/g3doc/rf_computation.md """ + # Convert a graph to graph_def if necessary. + if isinstance(graph_def, framework_ops.Graph): + graph_def = graph_def.as_graph_def() + + # Convert tensors to names. + if isinstance(input_node, framework_ops.Tensor): + input_node = input_node.op.name + if isinstance(output_node, framework_ops.Tensor): + output_node = output_node.op.name + + stop_propagation = stop_propagation or [] + # Computes order of computation for a given graph. name_to_order_node = graph_compute_order.get_compute_order( graph_def=graph_def) @@ -422,6 +527,10 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): # Loop over this node's inputs and potentially propagate information down. for inp_name in node.input: + # Stop the propagation of the receptive field. + if any(inp_name.startswith(stop) for stop in stop_propagation): + logging.vlog(3, "Skipping explicitly ignored node %s.", node.name) + continue logging.vlog(4, "inp_name = %s", inp_name) inp_node = name_to_order_node[inp_name].node logging.vlog(4, "inp_node = \n%s", inp_node) @@ -480,6 +589,7 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): raise ValueError("Output node was not found") if input_node not in rf_sizes_x: raise ValueError("Input node was not found") - return (rf_sizes_x[input_node], rf_sizes_y[input_node], - effective_strides_x[input_node], effective_strides_y[input_node], - effective_paddings_x[input_node], effective_paddings_y[input_node]) + return ReceptiveField( + (rf_sizes_x[input_node], rf_sizes_y[input_node]), + (effective_strides_x[input_node], effective_strides_y[input_node]), + (effective_paddings_x[input_node], effective_paddings_y[input_node])) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py index 2771389250b1518f33ebadf3f1cfd23e653dab93..8d7d5440f630a3a78749e04a5eb058d637c258fc 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test +import numpy as np def create_test_network_1(): @@ -150,6 +151,31 @@ def create_test_network_5(): return g +def create_test_network_6(): + """Aligned network with dropout for test. + + The graph is similar to create_test_network_1(), except that the right branch + has dropout normalization. + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]]) + l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID') + l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID') + dropout = slim.dropout(l3) + # Addition. + nn.relu(l1 + dropout, name='output') + return g + + class RfUtilsTest(test.TestCase): def testComputeRFFromGraphDefAligned(self): @@ -220,6 +246,36 @@ class RfUtilsTest(test.TestCase): self.assertEqual(effective_padding_x, 0) self.assertEqual(effective_padding_y, 0) + def testComputeRFFromGraphDefStopPropagation(self): + graph_def = create_test_network_6().as_graph_def() + input_node = 'input_image' + output_node = 'output' + # Compute the receptive field but stop the propagation for the random + # uniform variable of the dropout. + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node, + ['Dropout/dropout/random_uniform'])) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 1) + self.assertEqual(effective_padding_y, 1) + + def testComputeCoordinatesRoundtrip(self): + graph_def = create_test_network_1() + input_node = 'input_image' + output_node = 'output' + rf = receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node) + + x = np.random.randint(0, 100, (50, 2)) + y = rf.compute_feature_coordinates(x) + x2 = rf.compute_input_center_coordinates(y) + + self.assertAllEqual(x, x2) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 3e6c09662fe8b54ff4c07175cbba99b87e27969c..6395cd8316551336ead99a13594ad1919341c9cd 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -24,6 +24,22 @@ load( "tf_kernel_tests_linkstatic", ) +cc_library( + name = "all_ops", + deps = [ + ":gru_ops_op_lib", + ":lstm_ops_op_lib", + ], +) + +cc_library( + name = "all_kernels", + deps = [ + ":gru_ops_kernels", + ":lstm_ops_kernels", + ], +) + tf_custom_op_py_library( name = "rnn_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]) + [ @@ -34,14 +50,13 @@ tf_custom_op_py_library( ":python/ops/_lstm_ops.so", ], kernels = [ - ":gru_ops_kernels", - ":lstm_ops_kernels", - ":gru_ops_op_lib", - ":lstm_ops_op_lib", + ":all_ops", + ":all_kernels", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":benchmarking", ":gru_ops", ":lstm_ops", "//tensorflow/contrib/compiler:compiler_py", @@ -125,6 +140,9 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = [ + "optonly", + ], ) cuda_py_tests( @@ -386,3 +404,13 @@ py_test( "//tensorflow/python:variables", ], ) + +py_library( + name = "benchmarking", + srcs = ["python/kernel_tests/benchmarking.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index ffeb9953c576b76a613e1ba46ff087c184acd532..941a457fd3ada312b981fb23c769ff9ecea9ff13 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -41,6 +41,142 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { +template +void LSTMBlockCellFpropWithEigen( + const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d, + const T forget_bias, const T cell_clip, bool use_peephole, + typename TTypes::ConstMatrix x, typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::Matrix xh, typename TTypes::Matrix i, + typename TTypes::Matrix cs, typename TTypes::Matrix f, + typename TTypes::Matrix o, typename TTypes::Matrix ci, + typename TTypes::Matrix co, typename TTypes::Matrix icfo, + typename TTypes::Matrix h) { + // Concat xh = [x, h]. + xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x; + xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev; + + // states1 = xh * w + b + typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + Eigen::array b_shape({1, b.dimensions()[0]}); + Eigen::array broadcast_shape({cell.batch_size(), 1}); + icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); + + Eigen::array p_shape({1, cell.cell_size()}); + Eigen::array p_broadcast_shape({cell.batch_size(), 1}); + + // Input gate. + if (use_peephole) { + auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape); + i.device(d) = + (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep) + .sigmoid(); + } else { + i.device(d) = + icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid(); + } + + // Cell input. + ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh(); + + // Forget gate (w/ bias). + if (use_peephole) { + auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + + f.constant(forget_bias) + f_peep) + .sigmoid(); + } else { + f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + + f.constant(forget_bias)) + .sigmoid(); + } + + // cs = ci .* i + f .* cs_prev + cs.device(d) = i * ci + f * cs_prev; + + if (cell_clip > 0.0f) { + cs.device(d) = + cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); + } + + // co = tanh(cs) + co.device(d) = cs.tanh(); + + // Output gate. + if (use_peephole) { + auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape); + o.device(d) = + (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep) + .sigmoid(); + } else { + o.device(d) = + icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid(); + } + + // h = o .* co + h.device(d) = o * co; +} + +template +void LSTMBlockCellBpropWithEigen( + const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, + bool use_peephole, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, + typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, + typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, + typename TTypes::Vec wco_grad) { + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; + + Eigen::array p_shape({1, cell.cell_size()}); + Eigen::array p_broadcast_shape({cell.batch_size(), 1}); + if (use_peephole) { + dcs.device(d) = + dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; + + dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di; + dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci; + dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df; + dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_; + + cs_prev_grad.device(d) = dcs * f; + if (use_peephole) { + cs_prev_grad.device(d) = + cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); + wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); + wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); + } +} + #define DEFINE_CPU_SPECS(T) \ template <> \ void LSTMBlockCellFprop::operator()( \ @@ -55,7 +191,7 @@ namespace functor { typename TTypes::Matrix f, typename TTypes::Matrix o, \ typename TTypes::Matrix ci, typename TTypes::Matrix co, \ typename TTypes::Matrix icfo, typename TTypes::Matrix h) { \ - LSTMBlockCellFpropWithEigen( \ + LSTMBlockCellFpropWithEigen( \ *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \ h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h); \ } \ diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index 30a4b447068bf983714aea0f8da6dfecb96a5c7b..1906581b16b2e76243320bc67c8ac831323fb8e7 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -169,88 +169,6 @@ struct LSTMBlockCellFprop : public LSTMBlockCell { typename TTypes::Matrix h); }; -// TODO(b/63339763): Once GPUDevice implementation no longer relies on Eigen, -// move into lstm_ops.cc. -template -void LSTMBlockCellFpropWithEigen( - const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, - const T forget_bias, const T cell_clip, bool use_peephole, - typename TTypes::ConstMatrix x, typename TTypes::ConstMatrix cs_prev, - typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, - typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, - typename TTypes::ConstVec wco, typename TTypes::ConstVec b, - typename TTypes::Matrix xh, typename TTypes::Matrix i, - typename TTypes::Matrix cs, typename TTypes::Matrix f, - typename TTypes::Matrix o, typename TTypes::Matrix ci, - typename TTypes::Matrix co, typename TTypes::Matrix icfo, - typename TTypes::Matrix h) { - // Concat xh = [x, h]. - xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x; - xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev; - - // states1 = xh * w + b - typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); - TensorBlasGemm::compute(ctx, d, false, false, T(1), - const_xh, w, T(0), icfo); - Eigen::array b_shape({1, b.dimensions()[0]}); - Eigen::array broadcast_shape({cell.batch_size(), 1}); - icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); - - Eigen::array p_shape({1, cell.cell_size()}); - Eigen::array p_broadcast_shape({cell.batch_size(), 1}); - - // Input gate. - if (use_peephole) { - auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape); - i.device(d) = - (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep) - .sigmoid(); - } else { - i.device(d) = - icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid(); - } - - // Cell input. - ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh(); - - // Forget gate (w/ bias). - if (use_peephole) { - auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias) + f_peep) - .sigmoid(); - } else { - f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias)) - .sigmoid(); - } - - // cs = ci .* i + f .* cs_prev - cs.device(d) = i * ci + f * cs_prev; - - if (cell_clip > 0.0f) { - cs.device(d) = - cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); - } - - // co = tanh(cs) - co.device(d) = cs.tanh(); - - // Output gate. - if (use_peephole) { - auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape); - o.device(d) = - (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep) - .sigmoid(); - } else { - o.device(d) = - icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid(); - } - - // h = o .* co - h.device(d) = o * co; -} - // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for // GPUDevice implementation. template @@ -278,64 +196,6 @@ struct LSTMBlockCellBprop : public LSTMBlockCell { typename TTypes::Vec wco_grad); }; -// TODO(b/63339763): Once GPUDevice implementation no longer relies on Eigen, -// move into lstm_ops.cc. -template -void LSTMBlockCellBpropWithEigen( - const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, - bool use_peephole, typename TTypes::ConstMatrix x, - typename TTypes::ConstMatrix cs_prev, - typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, - typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, - typename TTypes::ConstVec wco, typename TTypes::ConstVec b, - typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, - typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, - typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, - typename TTypes::ConstMatrix cs_grad, - typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, - typename TTypes::Matrix dcs, typename TTypes::Matrix dci, - typename TTypes::Matrix df, typename TTypes::Matrix di, - typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, - typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, - typename TTypes::Vec wco_grad) { - // do[t] = sigm'(o[t]) .* dh[t] .* co[t] - do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; - - // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] - dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; - - Eigen::array p_shape({1, cell.cell_size()}); - Eigen::array p_broadcast_shape({cell.batch_size(), 1}); - if (use_peephole) { - dcs.device(d) = - dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); - } - - // dci[t] = tanh'(ci[t]) dcs[t] i[t] - dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; - - // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] - df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; - - // di[t] = sigm'(i[t]) dcs[t] ci[t] - di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; - - dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di; - dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci; - dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df; - dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_; - - cs_prev_grad.device(d) = dcs * f; - if (use_peephole) { - cs_prev_grad.device(d) = - cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + - df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); - wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); - wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); - } -} - template struct BlockLSTMBprop : public LSTMBlockCell { BlockLSTMBprop(const int batch_size, const int input_size, diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index e18f8079a384b97bde6223c06944b6c5226bee21..d82676ff7e620aef765e92137a2248c9bf1deedc 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -20,15 +20,334 @@ limitations under the License. #include "tensorflow/contrib/rnn/kernels/lstm_ops.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; -// TODO(b/63339763): Provide an alternative implementation for -// LSTMBlockCell{F,B}prop that doesn't rely on Eigen. +namespace { + +// Adds bias, applies non-linearities and gates. +// +// Launch with a 2D setup such that there is one thread per (example, +// activation) with 'x' governing example index and 'y' governing activation. +// +// Launch with blocks of (batch x 32) +// +// TODO(b/67600500): Try making 'use_peephole' a template parameter. +template +__global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, + const T* wci, const T* wcf, const T* wco, T* o, T* h, + T* ci, T* cs, T* co, T* i, T* f, const T forget_bias, + const T cell_clip, const int batch_size, + const int cell_size) { + const int batch_id = blockIdx.x * blockDim.x + threadIdx.x; + const int act_id = blockIdx.y * blockDim.y + threadIdx.y; + + if (batch_id >= batch_size || act_id >= cell_size) return; + + // The following code assumes the input arrays are of the following + // shapes and interpretations. + // + // 1) 'icfo' is a matrix such that, + // + // cell_size cell_size cell_size cell_size + // +----------+----------+----------+----------+ + // | | | | | + // | i | c | f | o | batch_size + // | | | | | + // +----------+----------+----------+----------+ + // + // 'gid' is the index assigned to this thread for 'icfo' in the 'i' submatrix. + // + // 2) 'b' is a vector such that, + // + // cell_size cell_size cell_size cell_size + // +----------+----------+----------+----------+ + // | i | c | f | o | 1 + // +----------+----------+----------+----------+ + // + // 'act_id' is the index assigned to this thread for 'b' in the 'i' subvector. + // + // 3) 'wc{i,f,o}' are vectors such that, + // + // cell_size + // +----------+ + // | i | 1 + // +----------+ + // + // 'act_id' is the index to this thread. + // + // 4) All other matrices have the form, + // + // cell_size + // +----------+ + // | | + // | i | batch_size + // | | + // +----------+ + // + // 'cid' is the index assigned to this thread. + // + const int gid = batch_id * cell_size * 4 + act_id; + const int cid = batch_id * cell_size + act_id; + Eigen::internal::scalar_sigmoid_op sigmoid_op; + Eigen::internal::scalar_tanh_op tanh_op; + Eigen::scalar_clip_op clip_op; + + T i_local; + if (use_peephole) { + i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id] + + cs_prev[cid] * wci[act_id]); + } else { + i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id]); + } + i[cid] = i_local; + + const T ci_local = + tanh_op(icfo[1 * cell_size + gid] + b[1 * cell_size + act_id]); + ci[cid] = ci_local; + + T f_local; + if (use_peephole) { + f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] + + forget_bias + cs_prev[cid] * wcf[act_id]); + } else { + f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] + + forget_bias); + } + f[cid] = f_local; + + T cs_local = i_local * ci_local + f_local * cs_prev[cid]; + if (cell_clip > 0.0) { + cs_local = clip_op(cs_local, cell_clip); + } + cs[cid] = cs_local; + + const T co_local = tanh_op(cs_local); + co[cid] = co_local; + + T o_local; + if (use_peephole) { + o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id] + + cs_local * wco[act_id]); + } else { + o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id]); + } + o[cid] = o_local; + + h[cid] = o_local * co_local; +} + +// Concatenate 'x' and 'h' and copy their contents into 'xh'. +template +__global__ void concat_xh(T* xh, const T* x, const T* h_prev, + const int batch_size, const int cell_size, + const int input_size) { + // Assumes 'x', 'h', and 'xh' are of the following shape, + // + // input_size cell_size + // +----------+----------+ + // | | | + // | x | h | batch_size + // | | | + // +----------+----------+ + // + const int gid = blockDim.x * blockIdx.x + threadIdx.x; + const int width = input_size + cell_size; + + if (gid >= width * batch_size) return; + + const int output_row = gid / width; + const int output_col = gid % width; + + if (output_col < input_size) { // x + xh[gid] = x[output_row * input_size + output_col]; + } else { // h + xh[gid] = h_prev[output_row * cell_size + output_col - input_size]; + } +} + +template +void LSTMBlockCellFpropWithCUDA( + OpKernelContext* ctx, const GPUDevice& d, const T forget_bias, + const T cell_clip, bool use_peephole, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::Matrix xh, typename TTypes::Matrix i, + typename TTypes::Matrix cs, typename TTypes::Matrix f, + typename TTypes::Matrix o, typename TTypes::Matrix ci, + typename TTypes::Matrix co, typename TTypes::Matrix icfo, + typename TTypes::Matrix h, int batch_size, int cell_size, + int input_size) { + const cudaStream_t& cu_stream = GetCudaStream(ctx); + + // Concatenate xh = [x, h]. + // + // Each block is assigned 128 threads. Good values are in [128, 1024] and are + // divisible by 32 (the size of a warp). The number of blocks is such that + // there are enough to process all the data. + const int block_dim = 128; + const int grid_dim = + Eigen::divup(batch_size * (cell_size + input_size), block_dim); + concat_xh<<>>( + xh.data(), x.data(), h_prev.data(), batch_size, cell_size, input_size); + + // states1 = xh * w + typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + + // Add bias, apply non-linearities and gating. + // + // Use 2D blocks. The number of threads per block is equal to x * y, where x = + // min(batch_size, 8) and y = 32. See above for guidance on number of + // threads. + dim3 block_dim_2d(min(batch_size, 8), 32); + dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast(block_dim_2d.x)), + Eigen::divup(cell_size, static_cast(block_dim_2d.y))); + + if (use_peephole) { + lstm_gates<<>>( + icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), + wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + } else { + lstm_gates<<>>( + icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), + wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + } +} + +template +__global__ void lstm_gates_bprop( + const T* cs_prev, // [batch_size, cell_size] + const T* h_prev, // [batch_size, cell_size] + const T* w, // [input_size + cell_size, 4 * cell_size] + const T* wci, // [cell_size] + const T* wcf, // [cell_size] + const T* wco, // [cell_size] + const T* b, // [4 * cell_size] + const T* i, // [batch_size, cell_size] + const T* cs, // [batch_size, cell_size] + const T* f, // [batch_size, cell_size] + const T* o, // [batch_size, cell_size] + const T* ci, // [batch_size, cell_size] + const T* co, // [batch_size, cell_size] + const T* cs_grad, // [batch_size, cell_size] + const T* h_grad, // [batch_size, cell_size] + T* do_, // [batch_size, cell_size] + T* dcs, // [batch_size, cell_size] + T* dci, // [batch_size, cell_size] + T* df, // [batch_size, cell_size] + T* di, // [batch_size, cell_size] + T* dicfo, // [input_size + cell_size, 4 * cell_size] + T* cs_prev_grad, // [batch_size, cell_size] + const int batch_size, const int cell_size, const bool use_peephole) { + const int batch_id = blockIdx.x * blockDim.x + threadIdx.x; + const int act_id = blockIdx.y * blockDim.y + threadIdx.y; + + if (batch_id >= batch_size || act_id >= cell_size) return; + + const int gid = batch_id * cell_size * 4 + act_id; + const int cid = batch_id * cell_size + act_id; + + const T one = static_cast(1.0f); + + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + const T o_local = o[cid]; + const T h_grad_local = h_grad[cid]; + const T co_local = co[cid]; + const T ci_local = ci[cid]; + const T do_local = o_local * (one - o_local) * h_grad_local * co_local; + const T i_local = i[cid]; + const T f_local = f[cid]; + + do_[cid] = do_local; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + T dcs_local = + (one - co_local * co_local) * h_grad_local * o_local + cs_grad[cid]; + if (use_peephole) { + dcs_local += do_local * wco[act_id]; + } + dcs[cid] = dcs_local; + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + const T dci_local = (one - ci_local * ci_local) * dcs_local * i_local; + dci[cid] = dci_local; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + const T df_local = f_local * (one - f_local) * dcs_local * cs_prev[cid]; + df[cid] = df_local; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + const T di_local = i_local * (one - i_local) * dcs_local * ci_local; + di[cid] = di_local; + + dicfo[gid + 0 * cell_size] = di_local; + dicfo[gid + 1 * cell_size] = dci_local; + dicfo[gid + 2 * cell_size] = df_local; + dicfo[gid + 3 * cell_size] = do_local; + + cs_prev_grad[cid] = dcs_local * f_local; + if (use_peephole) { + cs_prev_grad[cid] += di_local * wci[act_id] + df_local * wcf[act_id]; + } +} + +template +void LSTMBlockCellBpropWithCUDA( + OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, + typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, + typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, + typename TTypes::Vec wco_grad, const int batch_size, const int cell_size, + const bool use_peephole) { + const cudaStream_t& cu_stream = GetCudaStream(ctx); + + dim3 block_dim_2d(min(batch_size, 8), 32); + dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast(block_dim_2d.x)), + Eigen::divup(cell_size, static_cast(block_dim_2d.y))); + + lstm_gates_bprop<<>>( + cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(), + wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(), + co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(), + dci.data(), df.data(), di.data(), dicfo.data(), cs_prev_grad.data(), + batch_size, cell_size, use_peephole); + + if (use_peephole) { + Eigen::array p_shape({1, cell_size}); + Eigen::array p_broadcast_shape({batch_size, 1}); + cs_prev_grad.device(d) = + cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); + wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); + wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); + } +} + +} // namespace + #define DEFINE_GPU_SPECS(T) \ template struct TensorZero; \ template struct TensorUnalignedZero; \ @@ -49,9 +368,10 @@ typedef Eigen::GpuDevice GPUDevice; typename TTypes::Matrix f, typename TTypes::Matrix o, \ typename TTypes::Matrix ci, typename TTypes::Matrix co, \ typename TTypes::Matrix icfo, typename TTypes::Matrix h) { \ - LSTMBlockCellFpropWithEigen( \ - *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \ - h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h); \ + LSTMBlockCellFpropWithCUDA(ctx, d, forget_bias, cell_clip, use_peephole, \ + x, cs_prev, h_prev, w, wci, wcf, wco, b, xh, i, \ + cs, f, o, ci, co, icfo, h, batch_size_, \ + cell_size_, input_size_); \ } \ template <> \ void LSTMBlockCellBprop::operator()( \ @@ -73,10 +393,10 @@ typedef Eigen::GpuDevice GPUDevice; typename TTypes::Matrix cs_prev_grad, \ typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ typename TTypes::Vec wco_grad) { \ - LSTMBlockCellBpropWithEigen( \ - *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \ - i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, \ - cs_prev_grad, wci_grad, wcf_grad, wco_grad); \ + LSTMBlockCellBpropWithCUDA( \ + ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \ + cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, cs_prev_grad, wci_grad, \ + wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \ } \ template struct LSTMBlockCellFprop; \ template struct LSTMBlockCellBprop; \ diff --git a/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py b/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py new file mode 100644 index 0000000000000000000000000000000000000000..a48cd58706e72516f18098e643c0fa867d33beb2 --- /dev/null +++ b/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for benchmarking OpKernels.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import time + +from tensorflow.python.framework import ops + + +def device(use_gpu=False): + """TensorFlow device to assign ops to.""" + if use_gpu: + return ops.device("/gpu:0") + return ops.device("/cpu:0") + + +def seconds_per_run(op, sess, num_runs=50): + """Number of seconds taken to execute 'op' once on average.""" + for _ in range(2): + sess.run(op) + + start_time = time.time() + for _ in range(num_runs): + sess.run(op) + + end_time = time.time() + time_taken = (end_time - start_time) / num_runs + return time_taken + + +def dict_product(dicts): + """Constructs iterator over outer product of entries in a dict-of-lists. + + Example: + >>> dict_products({"a": [1,2], "b": [3, 4]}) + >>> [{"a": 1, "b": 3}, + {"a": 1, "b": 4}, + {"a": 2, "b": 3}, + {"a": 2, "b": 4}] + + Args: + dicts: dictionary with string keys and list values. + + Yields: + Individual dicts from outer product. + """ + keys, values = zip(*dicts.items()) + for config_values in itertools.product(*values): + yield dict(zip(keys, config_values)) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index f222c4745c13dc4b07fa5afa61fef5615bf0dba8..8349188f6f3ff36087e90d69a3ae3c3675ad3801 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -44,7 +44,7 @@ from tensorflow.python.framework import test_util # pylint: enable=protected-access -linear = rnn_cell_impl._linear +Linear = rnn_cell_impl._Linear # pylint: disable=invalid-name class RNNCellTest(test.TestCase): @@ -54,20 +54,20 @@ class RNNCellTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(1.0)): x = array_ops.zeros([1, 2]) - l = linear([x], 2, False) + l = Linear([x], 2, False)([x]) sess.run([variables_lib.global_variables_initializer()]) res = sess.run([l], {x.name: np.array([[1., 2.]])}) self.assertAllClose(res[0], [[3.0, 3.0]]) # Checks prevent you from accidentally creating a shared function. with self.assertRaises(ValueError): - l1 = linear([x], 2, False) + l1 = Linear([x], 2, False)([x]) # But you can create a new one in a new scope and share the variables. with variable_scope.variable_scope("l1") as new_scope: - l1 = linear([x], 2, False) + l1 = Linear([x], 2, False)([x]) with variable_scope.variable_scope(new_scope, reuse=True): - linear([l1], 2, False) + Linear([l1], 2, False)([l1]) self.assertEqual(len(variables_lib.trainable_variables()), 2) def testBasicRNNCell(self): @@ -141,58 +141,67 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.156736, 0.156736]]) def testBasicLSTMCell(self): - with self.test_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 8]) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=False) - g, out_m = cell(x, m) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual( - expected_variable_names, [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1.]]), - m.name: 0.1 * np.ones([1, 8])}) - self.assertEqual(len(res), 2) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) - expected_mem = np.array([[ - 0.68967271, 0.68967271, 0.44848421, 0.44848421, 0.39897051, - 0.39897051, 0.24024698, 0.24024698 - ]]) - self.assertAllClose(res[1], expected_mem) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros( - [1, 3]) # Test BasicLSTMCell with input_size != num_units. - m = array_ops.zeros([1, 4]) - g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1., 1.]]), - m.name: 0.1 * np.ones([1, 4])}) - self.assertEqual(len(res), 2) + for dtype in [dtypes.float16, dtypes.float32]: + np_dtype = dtype.as_numpy_dtype + with self.test_session(graph=ops.Graph()) as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2], dtype=dtype) + m = array_ops.zeros([1, 8], dtype=dtype) + cell = rnn_cell_impl.MultiRNNCell( + [ + rnn_cell_impl.BasicLSTMCell( + 2, state_is_tuple=False) + for _ in range(2) + ], + state_is_tuple=False) + self.assertEqual(cell.dtype, None) + g, out_m = cell(x, m) + # Layer infers the input type. + self.assertEqual(cell.dtype, dtype.name) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME + ] + self.assertEqual( + expected_variable_names, + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_m], + {x.name: np.array([[1., 1.]]), + m.name: 0.1 * np.ones([1, 8])}) + self.assertEqual(len(res), 2) + variables = variables_lib.global_variables() + self.assertEqual(expected_variable_names, [v.name for v in variables]) + # The numbers in results were not calculated, this is just a + # smoke test. + self.assertAllClose( + res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2) + expected_mem = np.array( + [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], + dtype=np_dtype) + self.assertAllClose(res[1], expected_mem, 1e-2) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test BasicLSTMCell with input_size != num_units. + x = array_ops.zeros([1, 3], dtype=dtype) + m = array_ops.zeros([1, 4], dtype=dtype) + g, out_m = rnn_cell_impl.BasicLSTMCell( + 2, state_is_tuple=False)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_m], + {x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)}) + self.assertEqual(len(res), 2) def testBasicLSTMCellDimension0Error(self): """Tests that dimension 0 in both(x and m) shape must be equal.""" @@ -441,6 +450,17 @@ class RNNCellTest(test.TestCase): outputs, _ = cell(x, m) self.assertTrue("cpu:14159" in outputs.device.lower()) + def _retrieve_cpu_gpu_stats(self, run_metadata): + cpu_stats = None + gpu_stats = None + step_stats = run_metadata.step_stats + for ds in step_stats.dev_stats: + if "cpu:0" in ds.device[-5:].lower(): + cpu_stats = ds.node_stats + if "gpu:0" == ds.device[-5:].lower(): + gpu_stats = ds.node_stats + return cpu_stats, gpu_stats + def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): if not test.is_gpu_available(): # Can't perform this test w/o a GPU @@ -462,10 +482,7 @@ class RNNCellTest(test.TestCase): sess.run([variables_lib.global_variables_initializer()]) _ = sess.run(outputs, options=opts, run_metadata=run_metadata) - step_stats = run_metadata.step_stats - ix = 0 if gpu_dev in step_stats.dev_stats[0].device else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) @@ -829,7 +846,8 @@ def basic_rnn_cell(inputs, state, num_units, scope=None): else: with variable_scope.variable_scope(scope, "basic_rnn_cell", [inputs, state]): - output = math_ops.tanh(linear([inputs, state], num_units, True)) + output = math_ops.tanh( + Linear([inputs, state], num_units, True)([inputs, state])) return output, output diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 40a3fb2fb0b174681252265d593de2935ee2efa2..2fa033632acb451762c60a68f659302102d6c3ab 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -2203,6 +2203,17 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): return run_metadata + def _retrieve_cpu_gpu_stats(self, run_metadata): + cpu_stats = None + gpu_stats = None + step_stats = run_metadata.step_stats + for ds in step_stats.dev_stats: + if "cpu:0" in ds.device[-5:].lower(): + cpu_stats = ds.node_stats + if "gpu:0" == ds.device[-5:].lower(): + gpu_stats = ds.node_stats + return cpu_stats, gpu_stats + def testRNNOnCPUCellOnGPU(self): if not test.is_gpu_available(): return # Test requires access to a GPU @@ -2210,10 +2221,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): gpu_dev = test.gpu_device_name() run_metadata = self._execute_rnn_on( rnn_device="/cpu:0", cell_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) @@ -2236,10 +2244,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): run_metadata = self._execute_rnn_on( rnn_device="/cpu:0", cell_device="/cpu:0", input_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) @@ -2255,10 +2260,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): gpu_dev = test.gpu_device_name() run_metadata = self._execute_rnn_on( input_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py index 4239e32ab93043c5054e5382e67e79047b9644bb..b865466cc75aa67fcd192f7726f65141409b896a 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py @@ -18,10 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time - import numpy as np +from tensorflow.contrib.rnn.python.kernel_tests import benchmarking from tensorflow.contrib.rnn.python.ops import gru_ops from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -333,20 +332,6 @@ class GRUBlockCellTest(test.TestCase): #### Benchmarking GRUBlockCell vs GRUCell. -def time_taken_by_op(op, sess, num_runs=50): - """Time taken by the Op.""" - for _ in range(2): - sess.run([op]) - - start_time = time.time() - for _ in range(num_runs): - sess.run([op]) - - end_time = time.time() - time_taken = end_time - start_time - return time_taken - - def training_gru_block_vs_gru_cell(batch_size, cell_size, input_size, @@ -357,7 +342,7 @@ def training_gru_block_vs_gru_cell(batch_size, ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: # Specify the device which is been used. - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): # Random initializers. seed = 1994 @@ -387,7 +372,8 @@ def training_gru_block_vs_gru_cell(batch_size, learning_rate).minimize(cost) # time for a training step. - basic_time_training = time_taken_by_op(optimizer, sess, iters) + basic_time_training = benchmarking.seconds_per_run( + optimizer, sess, iters) # Output from the basic GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -406,7 +392,8 @@ def training_gru_block_vs_gru_cell(batch_size, learning_rate).minimize(cost) # time for a training step. - block_time_training = time_taken_by_op(optimizer, sess, iters) + block_time_training = benchmarking.seconds_per_run( + optimizer, sess, iters) performance_training = ( basic_time_training - block_time_training) * 100 / basic_time_training @@ -429,7 +416,7 @@ def inference_gru_block_vs_gru_cell(batch_size, """Benchmark inference speed between GRUBlockCell vs GRUCell.""" ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): # Random initializers. seed = 1994 @@ -451,7 +438,8 @@ def inference_gru_block_vs_gru_cell(batch_size, time_major=True, dtype=dtypes.float32) sess.run([variables.global_variables_initializer()]) - basic_time_inference = time_taken_by_op(outputs_dynamic, sess, iters) + basic_time_inference = benchmarking.seconds_per_run( + outputs_dynamic, sess, iters) # Output from the block GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -463,7 +451,8 @@ def inference_gru_block_vs_gru_cell(batch_size, time_major=True, dtype=dtypes.float32) sess.run([variables.global_variables_initializer()]) - block_time_inference = time_taken_by_op(outputs_dynamic, sess, iters) + block_time_inference = benchmarking.seconds_per_run( + outputs_dynamic, sess, iters) performance_inference = (basic_time_inference - block_time_inference ) * 100 / basic_time_inference @@ -484,7 +473,7 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, """Benchmark single bprop step speed between GRUBlockCell vs GRUCell.""" ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): initializer = init_ops.random_uniform_initializer(-1, 1, seed=1989) # Inputs x = vs.get_variable("x", [batch_size, input_size]) @@ -496,7 +485,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, array_ops.identity(h)) sess.run([variables.global_variables_initializer()]) grad_output_wrt_input = gradients_impl.gradients([output], h) - basic_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters) + basic_time_bprop = benchmarking.seconds_per_run(grad_output_wrt_input, + sess, iters) # Output from the block GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -504,7 +494,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, array_ops.identity(h)) sess.run([variables.global_variables_initializer()]) grad_output_wrt_input = gradients_impl.gradients([output], h) - block_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters) + block_time_bprop = benchmarking.seconds_per_run(grad_output_wrt_input, + sess, iters) performance_inference = ( basic_time_bprop - block_time_bprop) * 100 / basic_time_bprop @@ -526,23 +517,29 @@ class BenchmarkGRUBlock(test.Benchmark): print("batch_size, cell_size, input_size, time_steps, GPU, " "basic_time_training, block_time_training, performance_training[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - for time_steps in [50]: - basic_time, block_time = training_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, time_steps, use_gpu, iters) - self.report_benchmark( - name="GRUCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=block_time) + + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512], + "time_steps": [50] + }): + basic_time, block_time = training_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=block_time) def benchmarkInferenceBlockGRUVsGRUCell(self): print("--------------------------------------------------------------") @@ -551,23 +548,28 @@ class BenchmarkGRUBlock(test.Benchmark): "batch_size, cell_size, input_size, time_steps, GPU, " "basic_time_inference, block_time_inference, performance_inference[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - for time_steps in [50]: - basic_time, block_time = inference_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, time_steps, use_gpu, iters) - self.report_benchmark( - name="GRUCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" - % (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=block_time) + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512], + "time_steps": [50] + }): + basic_time, block_time = inference_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=block_time) def benchmarkSingleBpropStepBlockGRUVsGRUCell(self): print("--------------------------------------------------------------") @@ -575,22 +577,27 @@ class BenchmarkGRUBlock(test.Benchmark): print("batch_size, cell_size, input_size, GPU, basic_time, " "block_time, performance_inference[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - basic_time, block_time = single_bprop_step_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, use_gpu, iters) - self.report_benchmark( - name="GRUCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % - (batch_size, cell_size, input_size, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" - % (batch_size, cell_size, input_size, use_gpu), - iters=iters, - wall_time=block_time) + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512] + }): + basic_time, block_time = single_bprop_step_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"]), + iters=iters, + wall_time=block_time) print("--------------------------------------------------------------") diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index 0ec37411f5f3d9b6687c077bf967b046068644ab..a288072ae5da0751f1999128029f38bea933490e 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -20,7 +20,9 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.rnn.python.kernel_tests import benchmarking from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,6 +38,111 @@ from tensorflow.python.platform import test block_lstm = lstm_ops._block_lstm # pylint: disable=protected-access +def blocks_match(sess, use_peephole): + batch_size = 2 + input_size = 3 + cell_size = 4 + sequence_length = 4 + + inputs = [] + for _ in range(sequence_length): + inp = ops.convert_to_tensor( + np.random.randn(batch_size, input_size), dtype=dtypes.float32) + inputs.append(inp) + + initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212) + + with variable_scope.variable_scope("test", initializer=initializer): + # magic naming so that the cells pick up these variables and resuse them + if use_peephole: + wci = variable_scope.get_variable( + "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32) + wcf = variable_scope.get_variable( + "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtypes.float32) + wco = variable_scope.get_variable( + "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtypes.float32) + + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + + if use_peephole: + wci_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_i_diag", + initializer=wci.initialized_value()) + wcf_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_f_diag", + initializer=wcf.initialized_value()) + wco_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_o_diag", + initializer=wco.initialized_value()) + w_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/kernel", + initializer=w.initialized_value()) + b_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/bias", + initializer=b.initialized_value()) + + basic_cell = rnn_cell.LSTMCell( + cell_size, use_peepholes=use_peephole, state_is_tuple=True, reuse=True) + basic_outputs_op, basic_state_op = rnn.static_rnn( + basic_cell, inputs, dtype=dtypes.float32) + + if use_peephole: + _, _, _, _, _, _, block_outputs_op = block_lstm( + ops.convert_to_tensor(sequence_length, dtype=dtypes.int64), + inputs, + w, + b, + wci=wci, + wcf=wcf, + wco=wco, + cell_clip=0, + use_peephole=True) + else: + _, _, _, _, _, _, block_outputs_op = block_lstm( + ops.convert_to_tensor(sequence_length, dtype=dtypes.int64), + inputs, + w, + b, + cell_clip=0) + + with variable_scope.variable_scope("rnn/lstm_cell", reuse=True): + fused_cell = lstm_ops.LSTMBlockFusedCell( + cell_size, cell_clip=0, use_peephole=use_peephole) + fused_outputs_op, fused_state_op = fused_cell( + inputs, dtype=dtypes.float32) + + sess.run([variables.global_variables_initializer()]) + basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]]) + basic_grads = sess.run(gradients_impl.gradients(basic_outputs_op, inputs)) + xs = [w, b] + if use_peephole: + xs += [wci, wcf, wco] + basic_wgrads = sess.run(gradients_impl.gradients(basic_outputs_op, xs)) + + block_outputs = sess.run(block_outputs_op) + block_grads = sess.run(gradients_impl.gradients(block_outputs_op, inputs)) + block_wgrads = sess.run(gradients_impl.gradients(block_outputs_op, xs)) + + xs = [w_block, b_block] + if use_peephole: + xs += [wci_block, wcf_block, wco_block] + fused_outputs, fused_state = sess.run([fused_outputs_op, fused_state_op[0]]) + fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs)) + fused_wgrads = sess.run(gradients_impl.gradients(fused_outputs_op, xs)) + + return (basic_state, fused_state, basic_outputs, block_outputs, + fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads, + block_wgrads, fused_wgrads) + + class LSTMBlockCellTest(test.TestCase): def testNoneDimsWithDynamicRNN(self): @@ -225,164 +332,39 @@ class LSTMBlockCellTest(test.TestCase): def testLSTMBasicToBlock(self): with self.test_session(use_gpu=True) as sess: - batch_size = 2 - input_size = 3 - cell_size = 4 - sequence_length = 5 - - inputs = [] - for _ in range(sequence_length): - inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) - inputs.append(inp) - - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=19890212) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) - outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - - with variable_scope.variable_scope("block", initializer=initializer): - w = variable_scope.get_variable( - "w", - shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) - b = variable_scope.get_variable( - "b", - shape=[cell_size * 4], - dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - - _, _, _, _, _, _, outputs = block_lstm( - ops.convert_to_tensor( - sequence_length, dtype=dtypes.int64), - inputs, - w, - b, - cell_clip=0) - - sess.run([variables.global_variables_initializer()]) - block_outputs = sess.run(outputs) - block_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - block_wgrads = sess.run(gradients_impl.gradients(outputs, [w, b])) + (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, + basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, + fused_wgrads) = blocks_match( + sess, use_peephole=False) self.assertAllClose(basic_outputs, block_outputs) self.assertAllClose(basic_grads, block_grads) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) - - with variable_scope.variable_scope("fused", initializer=initializer): - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=False) - outputs, state = cell(inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) self.assertAllClose(basic_outputs, fused_outputs) self.assertAllClose(basic_state, fused_state) self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + for basic, fused in zip(block_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) def testLSTMBasicToBlockPeeping(self): with self.test_session(use_gpu=True) as sess: - batch_size = 2 - input_size = 3 - cell_size = 4 - sequence_length = 5 - - inputs = [] - for _ in range(sequence_length): - inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) - inputs.append(inp) - - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=19890212) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.LSTMCell( - cell_size, use_peepholes=True, state_is_tuple=True) - outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - - with variable_scope.variable_scope("block", initializer=initializer): - w = variable_scope.get_variable( - "w", - shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) - b = variable_scope.get_variable( - "b", - shape=[cell_size * 4], - dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - - wci = variable_scope.get_variable( - "wci", shape=[cell_size], dtype=dtypes.float32) - wcf = variable_scope.get_variable( - "wcf", shape=[cell_size], dtype=dtypes.float32) - wco = variable_scope.get_variable( - "wco", shape=[cell_size], dtype=dtypes.float32) - - _, _, _, _, _, _, outputs = block_lstm( - ops.convert_to_tensor( - sequence_length, dtype=dtypes.int64), - inputs, - w, - b, - wci=wci, - wcf=wcf, - wco=wco, - cell_clip=0, - use_peephole=True) - - sess.run([variables.global_variables_initializer()]) - block_outputs = sess.run(outputs) - block_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - block_wgrads = sess.run( - gradients_impl.gradients(outputs, [w, b, wci, wcf, wco])) + (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, + basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, + fused_wgrads) = blocks_match( + sess, use_peephole=True) self.assertAllClose(basic_outputs, block_outputs) self.assertAllClose(basic_grads, block_grads) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) - - with variable_scope.variable_scope("fused", initializer=initializer): - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=True) - outputs, state = cell(inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) self.assertAllClose(basic_outputs, fused_outputs) self.assertAllClose(basic_state, fused_state) self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + for basic, fused in zip(block_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) def testLSTMFusedSequenceLengths(self): """Verify proper support for sequence lengths in LSTMBlockFusedCell.""" @@ -401,45 +383,40 @@ class LSTMBlockCellTest(test.TestCase): initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=19890213) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) - outputs, state = rnn.static_rnn( - cell, inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - with variable_scope.variable_scope("fused", initializer=initializer): + with variable_scope.variable_scope( + "lstm_block_wrapper", initializer=initializer): + # magic naming so that the cells pick up these variables and resuse them + variable_scope.get_variable( + "kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + + variable_scope.get_variable( + "bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): cell = lstm_ops.LSTMBlockFusedCell( cell_size, cell_clip=0, use_peephole=False) - outputs, state = cell( - inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + fused_outputs_op, fused_state_op = cell( + inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - self.assertAllClose(basic_outputs, fused_outputs) - self.assertAllClose(basic_state, fused_state) - self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + cell_vars = [ + v for v in variables.trainable_variables() + if v.name.endswith("kernel") or v.name.endswith("bias") + ] # Verify that state propagation works if we turn our sequence into # tiny (single-time) subsequences, i.e. unfuse the cell + unfused_outputs_op = [] + state = None with variable_scope.variable_scope( - "unfused", initializer=initializer) as vs: - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=False) - outputs = [] - state = None + variable_scope.get_variable_scope(), reuse=True): for i, inp in enumerate(inputs): lengths = [int(i < l) for l in seq_lengths.eval()] output, state = cell( @@ -447,25 +424,136 @@ class LSTMBlockCellTest(test.TestCase): initial_state=state, dtype=dtypes.float32, sequence_length=lengths) - vs.reuse_variables() - outputs.append(output[0]) - outputs = array_ops.stack(outputs) - - sess.run([variables.global_variables_initializer()]) - unfused_outputs, unfused_state = sess.run([outputs, state[0]]) - unfused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - unfused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("unfused/") - ] - unfused_wgrads = sess.run( - gradients_impl.gradients(outputs, unfused_vars)) - - self.assertAllClose(basic_outputs, unfused_outputs) - self.assertAllClose(basic_state, unfused_state) - self.assertAllClose(basic_grads, unfused_grads) - for basic, unfused in zip(basic_wgrads, unfused_wgrads): - self.assertAllClose(basic, unfused, rtol=1e-2, atol=1e-2) + unfused_outputs_op.append(output[0]) + unfused_outputs_op = array_ops.stack(unfused_outputs_op) + + sess.run([variables.global_variables_initializer()]) + unfused_outputs, unfused_state = sess.run([unfused_outputs_op, state[0]]) + unfused_grads = sess.run( + gradients_impl.gradients(unfused_outputs_op, inputs)) + unfused_wgrads = sess.run( + gradients_impl.gradients(unfused_outputs_op, cell_vars)) + + fused_outputs, fused_state = sess.run( + [fused_outputs_op, fused_state_op[0]]) + fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs)) + fused_wgrads = sess.run( + gradients_impl.gradients(fused_outputs_op, cell_vars)) + + self.assertAllClose(fused_outputs, unfused_outputs) + self.assertAllClose(fused_state, unfused_state) + self.assertAllClose(fused_grads, unfused_grads) + for fused, unfused in zip(fused_wgrads, unfused_wgrads): + self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6) + +#### Benchmarking. + + +class BenchmarkLSTMBlock(test.Benchmark): + + def benchmarkLSTMBlockCellFpropWithDynamicRNN(self): + print("BlockLSTMCell forward propagation via dynamic_rnn().") + print("--------------------------------------------------------------") + print("LSTMBlockCell Seconds per inference.") + print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time") + iters = 10 + for config in benchmarking.dict_product({ + "batch_size": [1, 8, 13, 32, 67, 128], + "cell_size": [128, 250, 512, 650, 1024, 1350], + "time_steps": [40], + "use_gpu": [True, False] + }): + with ops.Graph().as_default(): + with benchmarking.device(use_gpu=config["use_gpu"]): + inputs = variable_scope.get_variable( + "x", + [config["time_steps"], config["batch_size"], config["cell_size"]]) + cell = lstm_ops.LSTMBlockCell(config["cell_size"]) + outputs = rnn.dynamic_rnn( + cell, inputs, time_major=True, dtype=dtypes.float32) + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + wall_time = benchmarking.seconds_per_run(outputs, sess, iters) + + # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable + # is set, this will produce a copy-paste-able CSV file. + print(",".join( + map(str, [ + config["batch_size"], config["cell_size"], config["cell_size"], + config["time_steps"], config["use_gpu"], wall_time + ]))) + benchmark_name_template = "_".join([ + "LSTMBlockCell_fprop", "BS%(batch_size)i", "CS%(cell_size)i", + "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + ]) + + self.report_benchmark( + name=benchmark_name_template % config, + iters=iters, + wall_time=wall_time, + extras=config) + + def benchmarkLSTMBlockCellBpropWithDynamicRNN(self): + print("BlockLSTMCell backward propagation via dynamic_rnn().") + print("--------------------------------------------------------------") + print("LSTMBlockCell Seconds per inference.") + print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time") + iters = 10 + for config in benchmarking.dict_product({ + "batch_size": [1, 8, 13, 32, 67, 128], + "cell_size": [128, 250, 512, 650, 1024, 1350], + "time_steps": [40], + "use_gpu": [True, False] + }): + with ops.Graph().as_default(): + with benchmarking.device(use_gpu=config["use_gpu"]): + time_steps = config["time_steps"] + batch_size = config["batch_size"] + cell_size = input_size = config["cell_size"] + inputs = variable_scope.get_variable( + "x", [time_steps, batch_size, cell_size], + trainable=False, + dtype=dtypes.float32) + with variable_scope.variable_scope( + "rnn", reuse=variable_scope.AUTO_REUSE): + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + cell = lstm_ops.LSTMBlockCell(cell_size) + outputs = rnn.dynamic_rnn( + cell, inputs, time_major=True, dtype=dtypes.float32) + grads = gradients_impl.gradients(outputs, [inputs, w, b]) + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + wall_time = benchmarking.seconds_per_run(grads, sess, iters) + + # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable + # is set, this will produce a copy-paste-able CSV file. + print(",".join( + map(str, [ + batch_size, cell_size, cell_size, time_steps, config["use_gpu"], + wall_time + ]))) + benchmark_name_template = "_".join([ + "LSTMBlockCell_bprop", "BS%(batch_size)i", "CS%(cell_size)i", + "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + ]) + + self.report_benchmark( + name=benchmark_name_template % config, + iters=iters, + wall_time=wall_time, + extras=config) if __name__ == "__main__": diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py index 6b6bd503ceec8d0d7cd2bca5b7ec548fcf08445c..f877e4dacbf23df51e0f9231de60443bdce7b42c 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging RNNCell = rnn_cell_impl.RNNCell # pylint: disable=invalid-name -_linear = rnn_cell_impl._linear # pylint: disable=invalid-name, protected-access +_Linear = rnn_cell_impl._Linear # pylint: disable=invalid-name, protected-access _like_rnncell = rnn_cell_impl._like_rnncell # pylint: disable=invalid-name, protected-access @@ -154,6 +154,7 @@ class InputProjectionWrapper(RNNCell): self._cell = cell self._num_proj = num_proj self._activation = activation + self._linear = None @property def state_size(self): @@ -170,7 +171,9 @@ class InputProjectionWrapper(RNNCell): def call(self, inputs, state): """Run the input projection and then the cell.""" # Default scope: "InputProjectionWrapper" - projected = _linear(inputs, self._num_proj, True) + if self._linear is None: + self._linear = _Linear(inputs, self._num_proj, True) + projected = self._linear(inputs) if self._activation: projected = self._activation(projected) return self._cell(projected, state) @@ -208,6 +211,7 @@ class OutputProjectionWrapper(RNNCell): self._cell = cell self._output_size = output_size self._activation = activation + self._linear = None @property def state_size(self): @@ -224,7 +228,9 @@ class OutputProjectionWrapper(RNNCell): def call(self, inputs, state): """Run the cell and output projection on inputs, starting from state.""" output, res_state = self._cell(inputs, state) - projected = _linear(output, self._output_size, True) + if self._linear is None: + self._linear = _Linear(output, self._output_size, True) + projected = self._linear(output) if self._activation: projected = self._activation(projected) return projected, res_state diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index f591f7c84e50660ccddbe13e31a32f6bc273c460..df910a3423083972bdee42bec10733e37b8e5f96 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -92,7 +92,7 @@ def _lstm_block_cell(x, wco: A `Tensor`. Must have the same type as `x`. The weight matrix for output gate peephole connection. forget_bias: An optional `float`. Defaults to `1`. The forget gate bias. - cell_clip: An optional `float`. Defaults to `3`. + cell_clip: An optional `float`. Defaults to `-1` (no clipping). Value to clip the 'cs' value to. Disable by setting to negative value. use_peephole: An optional `bool`. Defaults to `False`. Whether to use peephole weights. @@ -116,8 +116,8 @@ def _lstm_block_cell(x, if cell_size is None: raise ValueError("cell_size from `cs_prev` should not be None.") wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) - wco = wci wcf = wci + wco = wci # pylint: disable=protected-access return gen_lstm_ops.lstm_block_cell( @@ -126,11 +126,11 @@ def _lstm_block_cell(x, h_prev=h_prev, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=forget_bias, - cell_clip=cell_clip, + cell_clip=cell_clip if cell_clip is not None else -1, use_peephole=use_peephole, name=name) # pylint: enable=protected-access @@ -162,7 +162,7 @@ def _block_lstm(seq_len_max, wcf: A `Tensor`. Must have the same type as `x`. wco: A `Tensor`. Must have the same type as `x`. forget_bias: An optional `float`. Defaults to `1`. - cell_clip: An optional `float`. Defaults to `3`. + cell_clip: An optional `float`. Defaults to `-1` (no clipping). use_peephole: An optional `bool`. Defaults to `False`. name: A name for the operation (optional). @@ -201,8 +201,8 @@ def _block_lstm(seq_len_max, h_prev = zero_state if wci is None: wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) - wco = wci wcf = wci + wco = wci # pylint: disable=protected-access i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm( @@ -212,11 +212,11 @@ def _block_lstm(seq_len_max, h_prev=h_prev, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=forget_bias, - cell_clip=cell_clip, + cell_clip=cell_clip if cell_clip is not None else -1, name=name, use_peephole=use_peephole) @@ -233,7 +233,7 @@ _lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"] @ops.RegisterGradient("LSTMBlockCell") def _LSTMBlockCellGrad(op, *grad): """Gradient for LSTMBlockCell.""" - (x, cs_prev, h_prev, w, wci, wco, wcf, b) = op.inputs + (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs (i, cs, f, o, ci, co, _) = op.outputs (_, cs_grad, _, _, _, _, h_grad) = grad @@ -293,13 +293,13 @@ def _LSTMBlockCellGrad(op, *grad): @ops.RegisterGradient("BlockLSTM") def _BlockLSTMGrad(op, *grad): """Gradient for BlockLSTM.""" - seq_len_max, x, cs_prev, h_prev, w, wci, wco, wcf, b = op.inputs + seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs i, cs, f, o, ci, co, h = op.outputs cs_grad = grad[1] h_grad = grad[6] - (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad, + (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad, b_grad) = gen_lstm_ops.block_lstm_grad( seq_len_max, x, @@ -307,8 +307,8 @@ def _BlockLSTMGrad(op, *grad): h_prev, w, wci, - wco, wcf, + wco, b, i, cs, @@ -321,8 +321,10 @@ def _BlockLSTMGrad(op, *grad): h_grad, use_peephole=op.get_attr("use_peephole")) - return [None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, - wcf_grad, b_grad] + return [ + None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, + wco_grad, b_grad + ] class LSTMBlockCell(rnn_cell_impl.RNNCell): @@ -341,7 +343,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): def __init__(self, num_units, forget_bias=1.0, - clip_cell=True, + cell_clip=None, use_peephole=False, reuse=None): """Initialize the basic LSTM cell. @@ -349,8 +351,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). - clip_cell: boolean, whether to apply cell clipping. See - `_lstm_block_cell()` for details. + cell_clip: An optional `float`. Defaults to `-1` (no clipping). use_peephole: Whether to use peephole connections or not. reuse: (optional) boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the @@ -363,13 +364,13 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): self._num_units = num_units self._forget_bias = forget_bias self._use_peephole = use_peephole - self._clip_cell = clip_cell + self._cell_clip = cell_clip if cell_clip is not None else -1 self._names = { "W": "kernel", "b": "bias", "wci": "w_i_diag", - "wco": "w_o_diag", "wcf": "w_f_diag", + "wco": "w_o_diag", "scope": "lstm_cell" } @@ -397,10 +398,10 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): initializer=init_ops.constant_initializer(0.0)) if self._use_peephole: wci = vs.get_variable(self._names["wci"], [self._num_units]) - wco = vs.get_variable(self._names["wco"], [self._num_units]) wcf = vs.get_variable(self._names["wcf"], [self._num_units]) + wco = vs.get_variable(self._names["wco"], [self._num_units]) else: - wci = wco = wcf = array_ops.zeros([self._num_units]) + wci = wcf = wco = array_ops.zeros([self._num_units]) (cs_prev, h_prev) = states_prev (_, cs, _, _, _, _, h) = _lstm_block_cell( x, @@ -409,10 +410,10 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): w, b, wci=wci, - wco=wco, wcf=wcf, + wco=wco, forget_bias=self._forget_bias, - cell_clip=None if self._clip_cell else -1, + cell_clip=self._cell_clip, use_peephole=self._use_peephole) new_state = rnn_cell_impl.LSTMStateTuple(cs, h) @@ -594,12 +595,12 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). - cell_clip: clip the cell to this value. Defaults to `3`. + cell_clip: clip the cell to this value. Default is no cell clipping. use_peephole: Whether to use peephole connections or not. """ self._num_units = num_units self._forget_bias = forget_bias - self._cell_clip = cell_clip + self._cell_clip = cell_clip if cell_clip is not None else -1 self._use_peephole = use_peephole @property @@ -645,10 +646,10 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): dtype=dtype) if self._use_peephole: wci = vs.get_variable("w_i_diag", [self._num_units], dtype=dtype) - wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype) wcf = vs.get_variable("w_f_diag", [self._num_units], dtype=dtype) + wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype) else: - wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype) + wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype) if sequence_length is None: max_seq_len = math_ops.to_int64(time_len) @@ -662,8 +663,8 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): h_prev=initial_output, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=self._forget_bias, cell_clip=self._cell_clip, diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 7b28222257f27b7e95f4215f5331eb475110dbb2..6702a89d22283be691deb11339d3374bf8c4fd93 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -525,7 +525,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): self._state_tuple_type = collections.namedtuple( "GridLSTMStateTuple", state_names.strip(",")) self._state_size = self._state_tuple_type( - *([num_units, num_units] * self._total_blocks)) + *([num_units, num_units] * self._total_blocks)) else: self._state_tuple_type = None self._state_size = num_units * self._total_blocks * 2 @@ -1017,7 +1017,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell): # pylint: disable=protected-access -_linear = rnn_cell_impl._linear +_Linear = rnn_cell_impl._Linear # pylint: disable=invalid-name # pylint: enable=protected-access @@ -1079,6 +1079,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): self._attn_size = attn_size self._attn_length = attn_length self._reuse = reuse + self._linear1 = None + self._linear2 = None + self._linear3 = None @property def state_size(self): @@ -1110,7 +1113,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): input_size = self._input_size if input_size is None: input_size = inputs.get_shape().as_list()[1] - inputs = _linear([inputs, attns], input_size, True) + if self._linear1 is None: + self._linear1 = _Linear([inputs, attns], input_size, True) + inputs = self._linear1([inputs, attns]) cell_output, new_state = self._cell(inputs, state) if self._state_is_tuple: new_state_cat = array_ops.concat(nest.flatten(new_state), 1) @@ -1118,7 +1123,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): new_state_cat = new_state new_attns, new_attn_states = self._attention(new_state_cat, attn_states) with vs.variable_scope("attn_output_projection"): - output = _linear([cell_output, new_attns], self._attn_size, True) + if self._linear2 is None: + self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True) + output = self._linear2([cell_output, new_attns]) new_attn_states = array_ops.concat( [new_attn_states, array_ops.expand_dims(output, 1)], 1) new_attn_states = array_ops.reshape( @@ -1141,7 +1148,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): hidden = array_ops.reshape(attn_states, [-1, self._attn_length, 1, self._attn_size]) hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") - y = _linear(query, self._attn_vec_size, True) + if self._linear3 is None: + self._linear3 = _Linear(query, self._attn_vec_size, True) + y = self._linear3(query) y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size]) s = reduce_sum(v * tanh(hidden_features + y), [2, 3]) a = softmax(s) @@ -1537,6 +1546,7 @@ class UGRNNCell(rnn_cell_impl.RNNCell): self._forget_bias = forget_bias self._activation = activation self._reuse = reuse + self._linear = None @property def state_size(self): @@ -1573,7 +1583,9 @@ class UGRNNCell(rnn_cell_impl.RNNCell): with vs.variable_scope(vs.get_variable_scope(), initializer=self._initializer): cell_inputs = array_ops.concat([inputs, state], 1) - rnn_matrix = _linear(cell_inputs, 2 * self._num_units, True) + if self._linear is None: + self._linear = _Linear(cell_inputs, 2 * self._num_units, True) + rnn_matrix = self._linear(cell_inputs) [g_act, c_act] = array_ops.split( axis=1, num_or_size_splits=2, value=rnn_matrix) @@ -1638,6 +1650,8 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): self._num_input_proj = num_in_proj self._y_activation = y_activation self._reuse = reuse + self._linear1 = None + self._linear2 = None @property def state_size(self): @@ -1680,7 +1694,9 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): if input_size.value != self._num_units: if self._num_input_proj: with vs.variable_scope("in_projection"): - inputs = _linear(inputs, self._num_units, True) + if self._linear1 is None: + self._linear1 = _Linear(inputs, self._num_units, True) + inputs = self._linear1(inputs) else: raise ValueError("Must have input size == output size for " "Intersection RNN. To fix, num_in_proj should " @@ -1688,7 +1704,9 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): n_dim = i_dim = self._num_units cell_inputs = array_ops.concat([inputs, state], 1) - rnn_matrix = _linear(cell_inputs, 2*n_dim + 2*i_dim, True) + if self._linear2 is None: + self._linear2 = _Linear(cell_inputs, 2*n_dim + 2*i_dim, True) + rnn_matrix = self._linear2(cell_inputs) gh_act = rnn_matrix[:, :n_dim] # b x n h_act = rnn_matrix[:, n_dim:2*n_dim] # b x n @@ -1825,6 +1843,9 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): self._period_init_min = period_init_min self._period_init_max = period_init_max self._reuse = reuse + self._linear1 = None + self._linear2 = None + self._linear3 = None @property def state_size(self): @@ -1872,14 +1893,18 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): in_mask_gates.append(c_prev) with vs.variable_scope("mask_gates"): + if self._linear1 is None: + self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True) + mask_gates = math_ops.sigmoid( - _linear(in_mask_gates, 2 * self._num_units, True)) + self._linear1(in_mask_gates)) [input_gate, forget_gate] = array_ops.split( axis=1, num_or_size_splits=2, value=mask_gates) with vs.variable_scope("new_input"): - new_input = math_ops.tanh( - _linear([x, h_prev], self._num_units, True)) + if self._linear2 is None: + self._linear2 = _Linear([x, h_prev], self._num_units, True) + new_input = math_ops.tanh(self._linear2([x, h_prev])) new_c = (c_prev * forget_gate + input_gate * new_input) @@ -1888,8 +1913,9 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): in_out_gate.append(new_c) with vs.variable_scope("output_gate"): - output_gate = math_ops.sigmoid( - _linear(in_out_gate, self._num_units, True)) + if self._linear3 is None: + self._linear3 = _Linear(in_out_gate, self._num_units, True) + output_gate = math_ops.sigmoid(self._linear3(in_out_gate)) new_h = math_ops.tanh(new_c) * output_gate @@ -2056,9 +2082,11 @@ def _conv(args, shape_length = len(shapes[0]) for shape in shapes: if len(shape) not in [3,4,5]: - raise ValueError("Conv Linear expects 3D, 4D or 5D arguments: %s" % str(shapes)) + raise ValueError("Conv Linear expects 3D, 4D " + "or 5D arguments: %s" % str(shapes)) if len(shape) != len(shapes[0]): - raise ValueError("Conv Linear expects all args to be of same Dimensiton: %s" % str(shapes)) + raise ValueError("Conv Linear expects all args " + "to be of same Dimension: %s" % str(shapes)) else: total_arg_size_depth += shape[-1] dtype = [a.dtype for a in args][0] @@ -2076,7 +2104,7 @@ def _conv(args, # Now the computation. kernel = vs.get_variable( - "kernel", + "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype) if len(args) == 1: @@ -2159,6 +2187,8 @@ class GLSTMCell(rnn_cell_impl.RNNCell): else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units + self._linear1 = None + self._linear2 = None @property def state_size(self): @@ -2227,7 +2257,9 @@ class GLSTMCell(rnn_cell_impl.RNNCell): self._group_shape[0]), self._get_input_for_group(m_prev, group_id, self._group_shape[0])], axis=1) - R_k = _linear(x_g_id, 4 * self._group_shape[1], bias=False) + if self._linear1 is None: + self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False) + R_k = self._linear1(x_g_id) # pylint: disable=invalid-name i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) i_parts.append(i_k) @@ -2267,7 +2299,9 @@ class GLSTMCell(rnn_cell_impl.RNNCell): if self._num_proj is not None: with vs.variable_scope("projection"): - m = _linear(m, self._num_proj, bias=False) + if self._linear2 is None: + self._linear2 = _Linear(m, self._num_proj, False) + m = self._linear2(m) new_state = rnn_cell_impl.LSTMStateTuple(c, m) return m, new_state diff --git a/tensorflow/contrib/s3/BUILD b/tensorflow/contrib/s3/BUILD index a4daed01e724e85741660324f89c7d4f3e98d5d4..b7bc1a11d6583787e2c0fb07d004dc2badc5bcca 100644 --- a/tensorflow/contrib/s3/BUILD +++ b/tensorflow/contrib/s3/BUILD @@ -9,6 +9,7 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", + "tf_cc_binary", "tf_cc_test", ) @@ -24,7 +25,7 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) -cc_binary( +tf_cc_binary( name = "s3_file_system.so", srcs = [ "s3_crypto.cc", diff --git a/tensorflow/contrib/s3/s3_crypto.cc b/tensorflow/contrib/s3/s3_crypto.cc index 1450384dc0f8b4d4f30c8776f6c1e31b0affeea7..bbd66371e41c5ecf4c6edfcb3a115cae2fb4e933 100644 --- a/tensorflow/contrib/s3/s3_crypto.cc +++ b/tensorflow/contrib/s3/s3_crypto.cc @@ -71,7 +71,7 @@ class S3Sha256OpenSSLImpl : public Aws::Utils::Crypto::Hash { SHA256_Init(&sha256); auto currentPos = stream.tellg(); - if (currentPos == -1) { + if (currentPos == std::streampos(std::streamoff(-1))) { currentPos = 0; stream.clear(); } diff --git a/tensorflow/contrib/s3/s3_file_system.cc b/tensorflow/contrib/s3/s3_file_system.cc index b09cf81d469e5c204412401a15b37c435c0ef816..daced83145353c52ae19e2b7e8491b5fcb31cc1f 100644 --- a/tensorflow/contrib/s3/s3_file_system.cc +++ b/tensorflow/contrib/s3/s3_file_system.cc @@ -222,7 +222,6 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { S3FileSystem::S3FileSystem() { Aws::SDKOptions options; - options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Info; options.cryptoOptions.sha256Factory_create_fn = []() { return Aws::MakeShared(S3CryptoAllocationTag); }; @@ -234,7 +233,6 @@ S3FileSystem::S3FileSystem() { S3FileSystem::~S3FileSystem() { Aws::SDKOptions options; - options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Info; Aws::ShutdownAPI(options); } diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 2caeb9eb614382c815984391df87a70516f519b2..8d4ec4b4dbe29a36ae3fb89f33b8dea21df1f56c 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -80,8 +80,7 @@ class TestEosMasking(test.TestCase): ]) eos_token = 0 - previously_finished = constant_op.constant( - [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32) + previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool) masked = beam_search_decoder._mask_probs(probs, eos_token, previously_finished) diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 9d67d5a0e0b481af69d8ebff56d81311c575fc63..839df079ee743c67b3eb6180bbf419f07ecb5435 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -342,7 +342,7 @@ class LuongAttention(_BaseAttentionMechanism): num_units: The depth of the attention mechanism. memory: The memory to query; usually the output of an RNN encoder. This tensor should be shaped `[batch_size, max_time, ...]`. - memory_sequence_length (optional): Sequence lengths for the batch entries + memory_sequence_length: (optional) Sequence lengths for the batch entries in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths. scale: Python boolean. Whether to scale the energy term. @@ -350,7 +350,7 @@ class LuongAttention(_BaseAttentionMechanism): probabilities. The default is @{tf.nn.softmax}. Other options include @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}. Its signature should be: `probabilities = probability_fn(score)`. - score_mask_value: (optional): The mask value for score before passing into + score_mask_value: (optional) The mask value for score before passing into `probability_fn`. The default is -inf. Only used if `memory_sequence_length` is not None. name: Name to use when creating ops. diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index e22912ac5c9e378587d092ae2bed56929fe2a8e7..112ac57a1bc82f9c64981a92641140fa03a30f04 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -20,9 +20,10 @@ from __future__ import print_function import collections +import numpy as np + from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -390,17 +391,17 @@ class BeamSearchDecoder(decoder.Decoder): We do this so that we can use nest and not run into problems with shapes. Args: - t: Tensor of dimension [batch_size*beam_width, s] - s: Tensor, Python int, or TensorShape. + t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`. + s: `Tensor`, Python int, or `TensorShape`. Returns: - Either a reshaped version of t with dimension - [batch_size, beam_width, s] if t's first dimension is of size - batch_size*beam_width or t if not. + If `t` is a matrix or higher order tensor, then the return value is + `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is + returned unchanged. Raises: - TypeError: If t is an instance of TensorArray. - ValueError: If the rank of t is not statically known. + TypeError: If `t` is an instance of `TensorArray`. + ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 1: @@ -411,19 +412,19 @@ class BeamSearchDecoder(decoder.Decoder): def _maybe_merge_batch_beams(self, t, s): """Splits the tensor from a batch by beams into a batch of beams. - More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We - reshape this into [batch_size, beam_width, s] + More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`, + then we reshape it to `[batch_size, beam_width] + s`. Args: - t: Tensor of dimension [batch_size*beam_width, s] - s: Tensor, Python int, or TensorShape. + t: `Tensor` of dimension `[batch_size * beam_width] + s`. + s: `Tensor`, Python int, or `TensorShape`. Returns: - A reshaped version of t with dimension [batch_size, beam_width, s]. + A reshaped version of t with shape `[batch_size, beam_width] + s`. Raises: - TypeError: If t is an instance of TensorArray. - ValueError: If the rank of t is not statically known. + TypeError: If `t` is an instance of `TensorArray`. + ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 2: @@ -521,14 +522,12 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # Calculate the continuation lengths by adding to all continuing beams. vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1] lengths_to_add = array_ops.one_hot( - indices=array_ops.tile( - array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]), + indices=array_ops.fill([batch_size, beam_width], end_token), depth=vocab_size, - on_value=constant_op.constant(0, dtype=dtypes.int64), - off_value=constant_op.constant(1, dtype=dtypes.int64), + on_value=np.int64(0), off_value=np.int64(1), dtype=dtypes.int64) - add_mask = (1 - math_ops.to_int64(previously_finished)) - lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add + add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished)) + lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) @@ -592,9 +591,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # 1. Finished beams remain unchanged # 2. Beams that are now finished (EOS predicted) remain unchanged # 3. Beams that are not yet finished have their length increased by 1 - lengths_to_add = math_ops.to_int64( - math_ops.not_equal(next_word_ids, end_token)) - lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add + lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished)) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, @@ -652,13 +649,20 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight): def _length_penalty(sequence_lengths, penalty_factor): """Calculates the length penalty. See https://arxiv.org/abs/1609.08144. + Returns the length penalty tensor: + ``` + [(5+sequence_lengths)/6]**penalty_factor + ``` + where all operations are performed element-wise. + Args: - sequence_lengths: The sequence length of all hypotheses, a tensor - of shape [beam_size, vocab_size]. + sequence_lengths: `Tensor`, the sequence lengths of each hypotheses. penalty_factor: A scalar that weights the length penalty. Returns: - The length penalty factor, a tensor fo shape [beam_size]. + If the penalty is `0`, returns the scalar `1.0`. Otherwise returns + the length penalty factor, a tensor with the same shape as + `sequence_lengths`. """ penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor") penalty_factor.set_shape(()) # penalty should be a scalar. @@ -680,8 +684,7 @@ def _mask_probs(probs, eos_token, finished): eos_token: An int32 id corresponding to the EOS token to allocate probability to. finished: A boolean tensor of shape `[batch_size, beam_width]` that - specifies which - elements in the beam are finished already. + specifies which elements in the beam are finished already. Returns: A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished @@ -689,10 +692,12 @@ def _mask_probs(probs, eos_token, finished): probability on the EOS token. """ vocab_size = array_ops.shape(probs)[2] - finished_mask = array_ops.expand_dims( - math_ops.to_float(1. - math_ops.to_float(finished)), 2) + finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype) + not_finished_mask = math_ops.cast( + array_ops.expand_dims(math_ops.logical_not(finished), 2), + probs.dtype) # These examples are not finished and we leave them - non_finished_examples = finished_mask * probs + non_finished_examples = not_finished_mask * probs # All finished examples are replaced with a vector that has all # probability on EOS finished_row = array_ops.one_hot( @@ -701,7 +706,7 @@ def _mask_probs(probs, eos_token, finished): dtype=probs.dtype, on_value=0., off_value=probs.dtype.min) - finished_examples = (1. - finished_mask) * finished_row + finished_examples = finished_mask * finished_row return finished_examples + non_finished_examples diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 8c11cf0d6450b5ea0f1d1af21c24a66c629cce90..2204b684ac993cd82e69b3fd74801bff610b5fd4 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -5,12 +5,14 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "cuda_py_tests") +load("//tensorflow:tensorflow.bzl", "py_test") # @unused py_library( name = "signal_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -24,16 +26,42 @@ py_library( ], ) +py_library( + name = "test_util", + srcs = ["python/kernel_tests/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework", + "//tensorflow/python:tf_optimizer", + ], +) + cuda_py_tests( name = "mel_ops_test", srcs = ["python/kernel_tests/mel_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:client_testlib", ], ) +cuda_py_tests( + name = "mfcc_ops_test", + srcs = ["python/kernel_tests/mfcc_ops_test.py"], + additional_deps = [ + ":signal_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:spectral_ops_test_util", + ], +) + cuda_py_tests( name = "reconstruction_ops_test", srcs = ["python/kernel_tests/reconstruction_ops_test.py"], @@ -56,6 +84,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/shape_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", @@ -93,6 +122,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/window_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index 25123b097e380a7590ea7377d6c979e449ec96b0..0f2592b0b05722145f1b323ada52fa53e6cdc4ba 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -20,6 +20,7 @@ See the @{$python/contrib.signal} guide. @@hamming_window @@hann_window @@inverse_stft +@@mfccs_from_log_mel_spectrograms @@linear_to_mel_weight_matrix @@overlap_and_add @@stft @@ -27,6 +28,7 @@ See the @{$python/contrib.signal} guide. [hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window [mel]: https://en.wikipedia.org/wiki/Mel_scale +[mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -35,6 +37,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix +from tensorflow.contrib.signal.python.ops.mfcc_ops import mfccs_from_log_mel_spectrograms from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add from tensorflow.contrib.signal.python.ops.shape_ops import frame # `frame` used to be named `frames`, which is a noun and not a verb. diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py index f107b53f01ca5422a57c6b03f6ec385d937bfead..b861476b67fc360f383465145ccd1cc620de5a99 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import mel_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.platform import test # mel spectrum constants and functions. @@ -159,6 +161,15 @@ class LinearToMelTest(test.TestCase): with self.assertRaises(ValueError): mel_ops.linear_to_mel_weight_matrix(dtype=dtypes.int32) + def test_constant_folding(self): + """Mel functions should be constant foldable.""" + for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + g = ops.Graph() + with g.as_default(): + mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype) + rewritten_graph = test_util.grappler_optimize(g, [mel_matrix]) + self.assertEqual(1, len(rewritten_graph.node)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c04f1cf5bad358a14a1827df05a129339502c86f --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mfcc_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.signal.python.ops import mfcc_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import spectral_ops_test_util +from tensorflow.python.platform import test + + +# TODO(rjryan): We have no open source tests for MFCCs at the moment. Internally +# at Google, this code is tested against a reference implementation that follows +# HTK conventions. +class MFCCTest(test.TestCase): + + def test_error(self): + # num_mel_bins must be positive. + with self.assertRaises(ValueError): + signal = array_ops.zeros((2, 3, 0)) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal) + + # signal must be float32 + with self.assertRaises(ValueError): + signal = array_ops.zeros((2, 3, 5), dtype=dtypes.float64) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal) + + def test_basic(self): + """A basic test that the op runs on random input.""" + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(use_gpu=True): + signal = random_ops.random_normal((2, 3, 5)) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py index 8633ced599f137da08a4181ec9cbf4b48517199d..1c052354b8afcc5fd8a53b783cc5c676588cf48c 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py @@ -20,9 +20,11 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import shape_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -334,5 +336,19 @@ class FrameTest(test.TestCase): signal, signal_shape, frames, frames.shape.as_list()) self.assertLess(error, 2e-5) + def test_constant_folding(self): + """frame should be constant foldable for constant inputs.""" + for pad_end in [False, True]: + g = ops.Graph() + with g.as_default(): + frame_length, frame_step = 32, 16 + signal_shape = (2, 128) + signal = array_ops.ones(signal_shape) + frames = shape_ops.frame(signal, frame_length, frame_step, + pad_end=pad_end) + rewritten_graph = test_util.grappler_optimize(g, [frames]) + self.assertEqual(1, len(rewritten_graph.node)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index 305a2b2eb9858b381988335caa5cc6b2e11e2bac..72d317dc418d313c1c59ac12019a0eee48261fe4 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -59,7 +59,11 @@ class SpectralOpsTest(test.TestCase): @staticmethod def _np_inverse_stft(stft, fft_length, hop_length, window_length): - frames = np.fft.irfft(stft, fft_length)[..., :window_length] + frames = np.fft.irfft(stft, fft_length) + # Pad or truncate frames's inner dimension to window_length. + frames = frames[..., :window_length] + frames = np.pad(frames, [[0, 0]] * (frames.ndim - 1) + + [[0, max(0, window_length - frames.shape[-1])]], "constant") window = SpectralOpsTest._np_hann_periodic_window(window_length) return SpectralOpsTest._np_overlap_add(frames * window, hop_length) @@ -79,12 +83,27 @@ class SpectralOpsTest(test.TestCase): self.test_session(use_gpu=True)) as sess: actual_stft = spectral_ops.stft( signal, frame_length, frame_step, fft_length, pad_end=False) + signal_ph = array_ops.placeholder(dtype=dtypes.as_dtype(signal.dtype)) + actual_stft_from_ph = spectral_ops.stft( + signal_ph, frame_length, frame_step, fft_length, pad_end=False) actual_inverse_stft = spectral_ops.inverse_stft( actual_stft, frame_length, frame_step, fft_length) - actual_stft, actual_inverse_stft = sess.run( - [actual_stft, actual_inverse_stft]) + actual_stft, actual_stft_from_ph, actual_inverse_stft = sess.run( + [actual_stft, actual_stft_from_ph, actual_inverse_stft], + feed_dict={signal_ph: signal}) + + actual_stft_ph = array_ops.placeholder(dtype=actual_stft.dtype) + actual_inverse_stft_from_ph = sess.run( + spectral_ops.inverse_stft( + actual_stft_ph, frame_length, frame_step, fft_length), + feed_dict={actual_stft_ph: actual_stft}) + + # Confirm that there is no difference in output when shape/rank is fully + # unknown or known. + self.assertAllClose(actual_stft, actual_stft_from_ph) + self.assertAllClose(actual_inverse_stft, actual_inverse_stft_from_ph) expected_stft = SpectralOpsTest._np_stft( signal, fft_length, frame_step, frame_length) @@ -142,6 +161,11 @@ class SpectralOpsTest(test.TestCase): self.assertAllEqual([64, 9], stft.shape.as_list()) self.assertAllEqual([64, 9], stft.eval().shape) + stft = spectral_ops.stft(signal, frame_length=16, frame_step=8, + fft_length=8, pad_end=True) + self.assertAllEqual([64, 5], stft.shape.as_list()) + self.assertAllEqual([64, 5], stft.eval().shape) + stft = np.zeros((32, 9)).astype(np.complex64) inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, @@ -156,6 +180,7 @@ class SpectralOpsTest(test.TestCase): test_configs = [ (512, 64, 32, 64), (512, 64, 64, 64), + (512, 72, 64, 64), (512, 64, 25, 64), (512, 25, 15, 36), (123, 23, 5, 42), diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3603b6a97ef7c3a4b940b83281ebceda93c9db --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utilities for tf.contrib.signal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.training import saver + + +def grappler_optimize(graph, fetches=None, rewriter_config=None): + """Tries to optimize the provided graph using grappler. + + Args: + graph: A @{tf.Graph} instance containing the graph to optimize. + fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away). + Grappler uses the 'train_op' collection to look for fetches, so if not + provided this collection should be non-empty. + rewriter_config: An optional @{tf.RewriterConfig} to use when rewriting the + graph. + + Returns: + A @{tf.GraphDef} containing the rewritten graph. + """ + if rewriter_config is None: + rewriter_config = rewriter_config_pb2.RewriterConfig() + if fetches is not None: + for fetch in fetches: + graph.add_to_collection('train_op', fetch) + metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def()) + return tf_optimizer.OptimizeGraph(rewriter_config, metagraph) diff --git a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py index c3e0464596244b331906dab47cee349c1ea737b5..5a464699dac5a737e0c6e0122a4a6699e945f695 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py @@ -22,8 +22,10 @@ import functools import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import window_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -91,6 +93,17 @@ class WindowOpsTest(test.TestCase): functools.partial(_scipy_raised_cosine, a=0.54, b=0.46), window_ops.hamming_window) + def test_constant_folding(self): + """Window functions should be constant foldable for constant inputs.""" + for window_fn in (window_ops.hann_window, window_ops.hamming_window): + for dtype, _ in self._dtypes: + for periodic in [False, True]: + g = ops.Graph() + with g.as_default(): + window = window_fn(100, periodic=periodic, dtype=dtype) + rewritten_graph = test_util.grappler_optimize(g, [window]) + self.assertEqual(1, len(rewritten_graph.node)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc7b57cd4f1033a8bda0845ccd8e777e0213d6b --- /dev/null +++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mel-Frequency Cepstral Coefficients (MFCCs) ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import spectral_ops + + +def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): + """Computes [MFCCs][mfcc] of `log_mel_spectrograms`. + + Implemented with GPU-compatible ops and supports gradients. + + [Mel-Frequency Cepstral Coefficient (MFCC)][mfcc] calculation consists of + taking the DCT-II of a log-magnitude mel-scale spectrogram. [HTK][htk]'s MFCCs + use a particular scaling of the DCT-II which is almost orthogonal + normalization. We follow this convention. + + All `num_mel_bins` MFCCs are returned and it is up to the caller to select + a subset of the MFCCs based on their application. For example, it is typical + to only use the first few for speech recognition, as this results in + an approximately pitch-invariant representation of the signal. + + For example: + + ```python + sample_rate = 16000.0 + # A Tensor of [batch_size, num_samples] mono PCM samples in the range [-1, 1]. + pcm = tf.placeholder(tf.float32, [None, None]) + + # A 1024-point STFT with frames of 64 ms and 75% overlap. + stfts = tf.contrib.signal.stft(pcm, frame_length=1024, frame_step=256, + fft_length=1024) + spectrograms = tf.abs(stft) + + # Warp the linear scale spectrograms into the mel-scale. + num_spectrogram_bins = stfts.shape[-1].value + lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 80 + linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix( + num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, + upper_edge_hertz) + mel_spectrograms = tf.tensordot( + spectrograms, linear_to_mel_weight_matrix, 1) + mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate( + linear_to_mel_weight_matrix.shape[-1:])) + + # Compute a stabilized log to get log-magnitude mel-scale spectrograms. + log_mel_spectrograms = tf.log(mel_spectrograms + 1e-6) + + # Compute MFCCs from log_mel_spectrograms and take the first 13. + mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms( + log_mel_spectrograms)[..., :13] + ``` + + Args: + log_mel_spectrograms: A `[..., num_mel_bins]` `float32` `Tensor` of + log-magnitude mel-scale spectrograms. + name: An optional name for the operation. + Returns: + A `[..., num_mel_bins]` `float32` `Tensor` of the MFCCs of + `log_mel_spectrograms`. + + Raises: + ValueError: If `num_mel_bins` is not positive. + + [mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum + [htk]: https://en.wikipedia.org/wiki/HTK_(software) + """ + with ops.name_scope(name, 'mfccs_from_log_mel_spectrograms', + [log_mel_spectrograms]): + # Compute the DCT-II of the resulting log-magnitude mel-scale spectrogram. + # The DCT used in HTK scales every basis vector by sqrt(2/N), which is the + # scaling required for an "orthogonal" DCT-II *except* in the 0th bin, where + # the true orthogonal DCT (as implemented by scipy) scales by sqrt(1/N). For + # this reason, we don't apply orthogonal normalization and scale the DCT by + # `0.5 * sqrt(2/N)` manually. + log_mel_spectrograms = ops.convert_to_tensor(log_mel_spectrograms, + dtype=dtypes.float32) + if (log_mel_spectrograms.shape.ndims and + log_mel_spectrograms.shape[-1].value is not None): + num_mel_bins = log_mel_spectrograms.shape[-1].value + if num_mel_bins == 0: + raise ValueError('num_mel_bins must be positive. Got: %s' % + log_mel_spectrograms) + else: + num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1] + + dct2 = spectral_ops.dct(log_mel_spectrograms) + return dct2 * math_ops.rsqrt(num_mel_bins * 2.0) diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py index 950d8f471c6b34ecd7488b4434776a333d2fa782..5ed109b7ddad126d16cf45c631434ba0a674896b 100644 --- a/tensorflow/contrib/signal/python/ops/spectral_ops.py +++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py @@ -28,6 +28,7 @@ from tensorflow.contrib.signal.python.ops import window_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops @@ -59,8 +60,7 @@ def stft(signals, frame_length, frame_step, fft_length=None, Raises: ValueError: If `signals` is not at least rank 1, `frame_length` is - not scalar, `frame_step` is not scalar, or `frame_length` - is greater than `fft_length`. + not scalar, or `frame_step` is not scalar. [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -78,15 +78,6 @@ def stft(signals, frame_length, frame_step, fft_length=None, else: fft_length = ops.convert_to_tensor(fft_length, name='fft_length') - frame_length_static = tensor_util.constant_value( - frame_length) - fft_length_static = tensor_util.constant_value(fft_length) - if (frame_length_static is not None and fft_length_static is not None and - frame_length_static > fft_length_static): - raise ValueError('frame_length (%d) may not be larger than ' - 'fft_length (%d)' % (frame_length_static, - fft_length_static)) - framed_signals = shape_ops.frame( signals, frame_length, frame_step, pad_end=pad_end) @@ -131,8 +122,7 @@ def inverse_stft(stfts, Raises: ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar, - `frame_step` is not scalar, or `fft_length` is not scalar, or - `frame_length` is greater than `fft_length`. + `frame_step` is not scalar, or `fft_length` is not scalar. [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -149,16 +139,40 @@ def inverse_stft(stfts, fft_length = ops.convert_to_tensor(fft_length, name='fft_length') fft_length.shape.assert_has_rank(0) - frame_length_static = tensor_util.constant_value( - frame_length) - fft_length_static = tensor_util.constant_value(fft_length) - if (frame_length_static is not None and fft_length_static is not None and - frame_length_static > fft_length_static): - raise ValueError('frame_length (%d) may not be larger than ' - 'fft_length (%d)' % (frame_length_static, - fft_length_static)) - - real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length] + real_frames = spectral_ops.irfft(stfts, [fft_length]) + + # frame_length may be larger or smaller than fft_length, so we pad or + # truncate real_frames to frame_length. + frame_length_static = tensor_util.constant_value(frame_length) + # If we don't know the shape of real_frames's inner dimension, pad and + # truncate to frame_length. + if (frame_length_static is None or + real_frames.shape.ndims is None or + real_frames.shape[-1].value is None): + real_frames = real_frames[..., :frame_length] + real_frames_rank = array_ops.rank(real_frames) + real_frames_shape = array_ops.shape(real_frames) + paddings = array_ops.concat( + [array_ops.zeros([real_frames_rank - 1, 2], + dtype=frame_length.dtype), + [[0, math_ops.maximum(0, frame_length - real_frames_shape[-1])]]], 0) + real_frames = array_ops.pad(real_frames, paddings) + # We know real_frames's last dimension and frame_length statically. If they + # are different, then pad or truncate real_frames to frame_length. + elif real_frames.shape[-1].value > frame_length_static: + real_frames = real_frames[..., :frame_length_static] + elif real_frames.shape[-1].value < frame_length_static: + pad_amount = frame_length_static - real_frames.shape[-1].value + real_frames = array_ops.pad(real_frames, + [[0, 0]] * (real_frames.shape.ndims - 1) + + [[0, pad_amount]]) + + # The above code pads the inner dimension of real_frames to frame_length, + # but it does so in a way that may not be shape-inference friendly. + # Restore shape information if we are able to. + if frame_length_static is not None and real_frames.shape.ndims is not None: + real_frames.set_shape([None] * (real_frames.shape.ndims - 1) + + [frame_length_static]) # Optionally window and overlap-add the inner 2 dimensions of real_frames # into a single [samples] dimension. diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index c0aa6d445acfc99ef9da9a54fc269babee754951..0bfd0801d55b25f78cee60e87ee6c43f11a4995c 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -574,7 +574,7 @@ with tf.Graph().as_default(): images, labels = ... # Define the model: - predictions = vgg.vgg16(images, is_training=True) + predictions = vgg.vgg_16(images, is_training=True) # Specify the loss function: slim.losses.softmax_cross_entropy(predictions, labels) diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index f9449095be0cb53a8c762eaa70f005f01645743d..7a56df9e97c21ec353ba3e610d30c5c434d665da 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -135,7 +135,10 @@ class BoundingBox(ItemHandler): """ sides = [] for key in self._full_keys: - side = array_ops.expand_dims(keys_to_tensors[key].values, 0) + side = keys_to_tensors[key] + if isinstance(side, sparse_tensor.SparseTensor): + side = side.values + side = array_ops.expand_dims(side, 0) sides.append(side) bounding_box = array_ops.concat(sides, 0) @@ -204,6 +207,42 @@ class Tensor(ItemHandler): return tensor +class LookupTensor(Tensor): + """An ItemHandler that returns a parsed Tensor, the result of a lookup.""" + + def __init__(self, + tensor_key, + table, + shape_keys=None, + shape=None, + default_value=''): + """Initializes the LookupTensor handler. + + See Tensor. Simply calls a vocabulary (most often, a label mapping) lookup. + + Args: + tensor_key: the name of the `TFExample` feature to read the tensor from. + table: A tf.lookup table. + shape_keys: Optional name or list of names of the TF-Example feature in + which the tensor shape is stored. If a list, then each corresponds to + one dimension of the shape. + shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is + reshaped accordingly. + default_value: The value used when the `tensor_key` is not found in a + particular `TFExample`. + + Raises: + ValueError: if both `shape_keys` and `shape` are specified. + """ + self._table = table + super(LookupTensor, self).__init__(tensor_key, shape_keys, shape, + default_value) + + def tensors_to_item(self, keys_to_tensors): + unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors) + return self._table.lookup(unmapped_tensor) + + class SparseTensor(ItemHandler): """An ItemHandler for SparseTensors.""" diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 96606b9c0e5b19a360f45ffe9922874cabe621e8..9c5a14d0060634a09fa1abf1d63aaf63eca7d096 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import image_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test @@ -692,7 +693,7 @@ class TFExampleDecoderTest(test.TestCase): else: self.assertAllClose(image, decoded_image, atol=0) - def testDecodeExampleWithBoundingBox(self): + def testDecodeExampleWithBoundingBoxSparse(self): num_bboxes = 10 np_ymin = np.random.rand(num_bboxes, 1) np_xmin = np.random.rand(num_bboxes, 1) @@ -731,6 +732,49 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllClose(np_bboxes, bboxes) + def testDecodeExampleWithBoundingBoxDense(self): + num_bboxes = 10 + np_ymin = np.random.rand(num_bboxes, 1) + np_xmin = np.random.rand(num_bboxes, 1) + np_ymax = np.random.rand(num_bboxes, 1) + np_xmax = np.random.rand(num_bboxes, 1) + np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax]) + + example = example_pb2.Example(features=feature_pb2.Features(feature={ + 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin), + 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin), + 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax), + 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax), + })) + serialized_example = example.SerializeToString() + + with self.test_session(): + serialized_example = array_ops.reshape(serialized_example, shape=[]) + + keys_to_features = { + 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + } + + items_to_handlers = { + 'object/bbox': + tfexample_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'], + 'image/object/bbox/'), + } + + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + [tf_bboxes] = decoder.decode(serialized_example, ['object/bbox']) + bboxes = tf_bboxes.eval() + + self.assertAllClose(np_bboxes, bboxes) + def testDecodeExampleWithRepeatedImages(self): image_shape = (2, 3, 3) image_format = 'png' @@ -768,6 +812,36 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllEqual(np.squeeze(output_image[0, :, :, :]), image) self.assertAllEqual(np.squeeze(output_image[1, :, :, :]), image) + def testDecodeExampleWithLookup(self): + + example = example_pb2.Example(features=feature_pb2.Features(feature={ + 'image/object/class/text': self._BytesFeature( + np.array(['cat', 'dog', 'guinea pig'])), + })) + serialized_example = example.SerializeToString() + # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2 + table = lookup_ops.index_table_from_tensor( + constant_op.constant(['dog', 'guinea pig', 'cat'])) + + with self.test_session() as sess: + sess.run(lookup_ops.tables_initializer()) + + serialized_example = array_ops.reshape(serialized_example, shape=[]) + + keys_to_features = { + 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string), + } + + items_to_handlers = { + 'labels': + tfexample_decoder.LookupTensor('image/object/class/text', table), + } + + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + obtained_class_ids = decoder.decode(serialized_example)[0].eval() + + self.assertAllClose([2, 0, 1], obtained_class_ids) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index e2035ab014cfd09682257fbbbf3a2868681aa850..7f03aaf085cf26e3f5f940f4388828006a02ef42 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -287,25 +287,6 @@ py_test( ], ) -py_test( - name = "resnet_is_training_test", - size = "medium", - srcs = ["resnet_is_training_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":resnet_utils", - ":resnet_v1", - ":resnet_v2", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//third_party/py/numpy", - ], -) - py_library( name = "vgg", srcs = ["vgg.py"], diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py deleted file mode 100644 index 9a165577b699f757057aa10cc14bc1d48c02343a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Specifying is_training in resnet_arg_scope is being deprecated. - -Test that everything behaves as expected in the meantime. - -Note: This test modifies the layers.batch_norm function. -Other tests that use layers.batch_norm may not work if added to this file. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import add_arg_scope -from tensorflow.contrib.framework.python.ops import arg_scope -from tensorflow.contrib.slim.python.slim.nets import resnet_utils -from tensorflow.contrib.slim.python.slim.nets import resnet_v1 -from tensorflow.contrib.slim.python.slim.nets import resnet_v2 -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test - - -def create_test_input(batch, height, width, channels): - """Create test input tensor.""" - if None in [batch, height, width, channels]: - return array_ops.placeholder(dtypes.float32, (batch, height, width, - channels)) - else: - return math_ops.to_float( - np.tile( - np.reshape( - np.reshape(np.arange(height), [height, 1]) + - np.reshape(np.arange(width), [1, width]), - [1, height, width, 1]), - [batch, 1, 1, channels])) - - -class ResnetIsTrainingTest(test.TestCase): - - def _testDeprecatingIsTraining(self, network_fn): - batch_norm_fn = layers.batch_norm - - @add_arg_scope - def batch_norm_expect_is_training(*args, **kwargs): - assert kwargs['is_training'] - return batch_norm_fn(*args, **kwargs) - - @add_arg_scope - def batch_norm_expect_is_not_training(*args, **kwargs): - assert not kwargs['is_training'] - return batch_norm_fn(*args, **kwargs) - - global_pool = True - num_classes = 10 - inputs = create_test_input(2, 224, 224, 3) - - # Default argument for resnet_arg_scope - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope()): - network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet1') - - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope()): - network_fn( - inputs, - num_classes, - is_training=True, - global_pool=global_pool, - scope='resnet2') - - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope()): - network_fn( - inputs, - num_classes, - is_training=False, - global_pool=global_pool, - scope='resnet3') - - # resnet_arg_scope with is_training set to True (deprecated) - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): - network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet4') - - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): - network_fn( - inputs, - num_classes, - is_training=True, - global_pool=global_pool, - scope='resnet5') - - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): - network_fn( - inputs, - num_classes, - is_training=False, - global_pool=global_pool, - scope='resnet6') - - # resnet_arg_scope with is_training set to False (deprecated) - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet7') - - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - network_fn( - inputs, - num_classes, - is_training=True, - global_pool=global_pool, - scope='resnet8') - - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - network_fn( - inputs, - num_classes, - is_training=False, - global_pool=global_pool, - scope='resnet9') - - layers.batch_norm = batch_norm_fn - - def testDeprecatingIsTrainingResnetV1(self): - self._testDeprecatingIsTraining(resnet_v1.resnet_v1_50) - - def testDeprecatingIsTrainingResnetV2(self): - self._testDeprecatingIsTraining(resnet_v2.resnet_v2_50) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py index 58614a998abc2a983c4cd8df934cb30090c6443f..cfafee5d8c7a8dd326f6512b9aa224c78ccfb3d4 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py @@ -41,7 +41,6 @@ from __future__ import print_function import collections from tensorflow.contrib import layers as layers_lib -from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.ops import add_arg_scope from tensorflow.contrib.framework.python.ops import arg_scope from tensorflow.contrib.layers.python.layers import initializers @@ -223,12 +222,7 @@ def stack_blocks_dense(net, return net -@deprecated_args( - '2017-08-01', - 'Pass is_training directly to the network instead of the arg_scope.', - 'is_training') -def resnet_arg_scope(is_training=True, - weight_decay=0.0001, +def resnet_arg_scope(weight_decay=0.0001, batch_norm_decay=0.997, batch_norm_epsilon=1e-5, batch_norm_scale=True): @@ -240,8 +234,6 @@ def resnet_arg_scope(is_training=True, training ResNets from scratch, they might need to be tuned. Args: - is_training: Whether or not we are training the parameters in the batch - normalization layers of the model. (deprecated) weight_decay: The weight decay to use for regularizing the model. batch_norm_decay: The moving average decay when estimating layer activation statistics in batch normalization. @@ -254,7 +246,6 @@ def resnet_arg_scope(is_training=True, An `arg_scope` to use for the resnet models. """ batch_norm_params = { - 'is_training': is_training, 'decay': batch_norm_decay, 'epsilon': batch_norm_epsilon, 'scale': batch_norm_scale, @@ -266,7 +257,8 @@ def resnet_arg_scope(is_training=True, weights_regularizer=regularizers.l2_regularizer(weight_decay), weights_initializer=initializers.variance_scaling_initializer(), activation_fn=nn_ops.relu, - normalizer_fn=layers.batch_norm): + normalizer_fn=layers.batch_norm, + normalizer_params=batch_norm_params): with arg_scope([layers.batch_norm], **batch_norm_params): # The following implies padding='SAME' for pool1, which makes feature # alignment easier for dense prediction tasks. This is also used in diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py index 90f93d46e34b7554353d74529360d8e9a8ff5d06..235a595de49f956e1df740fd821936c80eefaa55 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py @@ -128,7 +128,7 @@ def bottleneck(inputs, def resnet_v1(inputs, blocks, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, @@ -163,8 +163,7 @@ def resnet_v1(inputs, is a resnet_utils.Block object describing the units in the block. num_classes: Number of predicted classes for classification tasks. If None we return the features before the logit layer. - is_training: whether is training or not. If None, the value inherited from - the resnet_arg_scope is used. Specifying value None is deprecated. + is_training: whether batch_norm layers are in training mode. global_pool: If True, we perform global average pooling before computing the logits. Set to True for image classification, False for dense prediction. output_stride: If None, then the output will be computed at the nominal @@ -196,11 +195,7 @@ def resnet_v1(inputs, with arg_scope( [layers.conv2d, bottleneck, resnet_utils.stack_blocks_dense], outputs_collections=end_points_collection): - if is_training is not None: - bn_scope = arg_scope([layers.batch_norm], is_training=is_training) - else: - bn_scope = arg_scope([]) - with bn_scope: + with arg_scope([layers.batch_norm], is_training=is_training): net = inputs if include_root_block: if output_stride is not None: @@ -255,7 +250,7 @@ def resnet_v1_block(scope, base_depth, num_units, stride): def resnet_v1_50(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -281,7 +276,7 @@ def resnet_v1_50(inputs, def resnet_v1_101(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -307,7 +302,7 @@ def resnet_v1_101(inputs, def resnet_v1_152(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -333,7 +328,7 @@ def resnet_v1_152(inputs, def resnet_v1_200(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py index d510337fef0762e086aee7341d4739393ee165f8..b4fd2580c2b8eaef79c1dd5f2f6b4a18cd0904c7 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py @@ -250,7 +250,7 @@ class ResnetCompleteNetworkTest(test.TestCase): def _resnet_small(self, inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py index 63e8f1ff356dfcf0427d5170a03faa47ee06298c..61665c9c8ba7817377a16bf3f2673447cab0518e 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py @@ -130,7 +130,7 @@ def bottleneck(inputs, def resnet_v2(inputs, blocks, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, @@ -165,8 +165,7 @@ def resnet_v2(inputs, is a resnet_utils.Block object describing the units in the block. num_classes: Number of predicted classes for classification tasks. If None we return the features before the logit layer. - is_training: whether is training or not. If None, the value inherited from - the resnet_arg_scope is used. Specifying value None is deprecated. + is_training: whether batch_norm layers are in training mode. global_pool: If True, we perform global average pooling before computing the logits. Set to True for image classification, False for dense prediction. output_stride: If None, then the output will be computed at the nominal @@ -200,11 +199,7 @@ def resnet_v2(inputs, with arg_scope( [layers_lib.conv2d, bottleneck, resnet_utils.stack_blocks_dense], outputs_collections=end_points_collection): - if is_training is not None: - bn_scope = arg_scope([layers.batch_norm], is_training=is_training) - else: - bn_scope = arg_scope([]) - with bn_scope: + with arg_scope([layers.batch_norm], is_training=is_training): net = inputs if include_root_block: if output_stride is not None: @@ -268,7 +263,7 @@ def resnet_v2_block(scope, base_depth, num_units, stride): def resnet_v2_50(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -294,8 +289,8 @@ def resnet_v2_50(inputs, def resnet_v2_101(inputs, num_classes=None, + is_training=True, global_pool=True, - is_training=None, output_stride=None, reuse=None, scope='resnet_v2_101'): @@ -320,7 +315,7 @@ def resnet_v2_101(inputs, def resnet_v2_152(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -346,7 +341,7 @@ def resnet_v2_152(inputs, def resnet_v2_200(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py index c4f3b071fd940d2c3d7c80fa3041b0426e336ab0..6bdda18c5ba8fe0c9d3374010266c3391044a206 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py @@ -254,7 +254,7 @@ class ResnetCompleteNetworkTest(test.TestCase): def _resnet_small(self, inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index ceaf83b70a76e8a1195b4c177f4764dc7ab792f2..c8d0c14e1951a7c29eed096d2a2e9849c4326245 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -106,7 +106,8 @@ def summary_writer_function(name, tensor, function, family=None): function(tag, scope) return True - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return control_flow_ops.cond( + should_record_summaries(), record, _nothing, name="") def generic(name, tensor, metadata, family=None): diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index c9a9bb3d5b17c309e136f902505bf1fc9e5295aa..6958ee8dd83600d130293322c8680b3c0c0c02b2 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -79,6 +79,23 @@ class TargetTest(test_util.TensorFlowTestCase): event.ParseFromString(records[1]) self.assertEqual(event.summary.value[0].simple_value, 2.0) + def testSummaryName(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2') + summary_ops.always_record_summaries() + + summary_ops.scalar('scalar', 2.0) + + self.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + self.assertEqual(len(files), 1) + records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + self.assertEqual(len(records), 2) + event = event_pb2.Event() + event.ParseFromString(records[1]) + self.assertEqual(event.summary.value[0].tag, 'scalar') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 29e0d6af78e05cc0c509acd540a67519d557e57a..b9aad36f3d25b9fb7b8b525be54fb7a39394b373 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -271,9 +271,6 @@ class TraverseTreeV4Op : public OpKernel { string serialized_proto; OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); input_spec_.ParseFromString(serialized_proto); - - data_set_ = - std::unique_ptr(new TensorDataSet(input_spec_, 0)); } void Compute(OpKernelContext* context) override { @@ -282,8 +279,9 @@ class TraverseTreeV4Op : public OpKernel { const Tensor& sparse_input_values = context->input(3); const Tensor& sparse_input_shape = context->input(4); - data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values, sparse_input_shape); + std::unique_ptr data_set(new TensorDataSet(input_spec_, 0)); + data_set->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); DecisionTreeResource* decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), @@ -291,7 +289,7 @@ class TraverseTreeV4Op : public OpKernel { mutex_lock l(*decision_tree_resource->get_mutex()); core::ScopedUnref unref_me(decision_tree_resource); - const int num_data = data_set_->NumItems(); + const int num_data = data_set->NumItems(); Tensor* output_predictions = nullptr; TensorShape output_shape; @@ -306,11 +304,11 @@ class TraverseTreeV4Op : public OpKernel { auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data]( - int64 start, int64 end) { + auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource, + num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - TraverseTree(decision_tree_resource, data_set_, static_cast(start), + TraverseTree(decision_tree_resource, data_set, static_cast(start), static_cast(end), set_leaf_ids, nullptr); }; Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, @@ -319,7 +317,6 @@ class TraverseTreeV4Op : public OpKernel { private: tensorforest::TensorForestDataSpec input_spec_; - std::unique_ptr data_set_; TensorForestParams param_proto_; }; diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index b6d57ef952777bc204f9534e60f2ce7de3687615..f80a34ece662d1e0b0ea1cb7616fa1b5b84731fa 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -235,9 +235,6 @@ class ProcessInputOp : public OpKernel { string serialized_proto; OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); input_spec_.ParseFromString(serialized_proto); - - data_set_ = std::unique_ptr( - new TensorDataSet(input_spec_, random_seed_)); } void Compute(OpKernelContext* context) override { @@ -249,8 +246,9 @@ class ProcessInputOp : public OpKernel { const Tensor& input_weights = context->input(7); const Tensor& leaf_ids_tensor = context->input(8); - data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values, sparse_input_shape); + std::unique_ptr data_set(new TensorDataSet(input_spec_, 0)); + data_set->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); FertileStatsResource* fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), @@ -264,7 +262,7 @@ class ProcessInputOp : public OpKernel { core::ScopedUnref unref_stats(fertile_stats_resource); core::ScopedUnref unref_tree(tree_resource); - const int32 num_data = data_set_->NumItems(); + const int32 num_data = data_set->NumItems(); auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; @@ -308,23 +306,23 @@ class ProcessInputOp : public OpKernel { // from a digits run on local desktop. Heuristics might be necessary // if it really matters that much. const int64 costPerUpdate = 1000; - auto update = [this, &target, &leaf_ids_tensor, &num_targets, + auto update = [this, &target, &leaf_ids_tensor, &num_targets, &data_set, fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - UpdateStats(fertile_stats_resource, data_set_, target, num_targets, + UpdateStats(fertile_stats_resource, data_set, target, num_targets, leaf_ids_tensor, &locks, &set_lock, static_cast(start), static_cast(end), &ready_to_split); }; auto update_collated = [this, &target, &num_targets, fertile_stats_resource, tree_resource, &leaf_examples, &set_lock, - &ready_to_split, + &ready_to_split, &data_set, num_leaves](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_leaves); - UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_, + UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set, target, num_targets, leaf_examples, &set_lock, static_cast(start), static_cast(end), &ready_to_split); @@ -350,7 +348,6 @@ class ProcessInputOp : public OpKernel { private: int32 random_seed_; tensorforest::TensorForestDataSpec input_spec_; - std::unique_ptr data_set_; TensorForestParams param_proto_; }; diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 756533250a62d9eb01ae9d2c80125272aeabca4c..eb938763f12efd9281bec4321384acd4617cdfcf 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -470,7 +470,11 @@ class RandomForestGraphs(object): """Constructs a TF graph for evaluating a random forest. Args: - input_data: A tensor or dict of string->Tensor for input data. + input_data: A tensor or dict of string->Tensor for the input data. + This input_data must generate the same spec as the + input_data used in training_graph: the dict must have + the same keys, for example, and all tensors must have + the same size in their first dimension. **inference_args: Keyword arguments to pass through to each tree. Returns: diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 015d0eba29f281d78ed6717271987cf3f2e121e9..222a77c4898bf705f98f98fba841bbfff5e852cc 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -25,6 +25,7 @@ py_test( srcs = ["predict_test.py"], data = ["data/period_trend.csv"], srcs_version = "PY2AND3", + tags = ["notsan"], # b/67513579 deps = [ ":predict", "//tensorflow/python:client_testlib", @@ -96,6 +97,7 @@ py_test( timeout = "long", # Moderate but for asan srcs = ["lstm_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [":lstm"], ) diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index 6bab06f56c859705597027369147643a43ce01c0..3ba823f638da8f750981bc910d960706ff652fb7 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -106,16 +106,6 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): for state_element in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)]) - def _transform(self, data): - """Normalize data based on input statistics to encourage stable training.""" - mean, variance = self._input_statistics.overall_feature_moments - return (data - mean) / variance - - def _de_transform(self, data): - """Transform data back to the input scale.""" - mean, variance = self._input_statistics.overall_feature_moments - return data * variance + mean - def _filtering_step(self, current_times, current_values, state, predictions): """Update model state based on observations. @@ -140,7 +130,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): state_from_time, prediction, lstm_state = state with tf.control_dependencies( [tf.assert_equal(current_times, state_from_time)]): - transformed_values = self._transform(current_values) + # Subtract the mean and divide by the variance of the series. Slightly + # more efficient if done for a whole window (using the normalize_features + # argument to SequentialTimeSeriesModel). + transformed_values = self._scale_data(current_values) # Use mean squared error across features for the loss. predictions["loss"] = tf.reduce_mean( (prediction - transformed_values) ** 2, axis=-1) @@ -156,7 +149,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): inputs=previous_observation_or_prediction, state=lstm_state) next_prediction = self._predict_from_lstm_output(lstm_output) new_state_tuple = (current_times, next_prediction, new_lstm_state) - return new_state_tuple, {"mean": self._de_transform(next_prediction)} + return new_state_tuple, {"mean": self._scale_back_data(next_prediction)} def _imputation_step(self, current_times, state): """Advance model state across a gap.""" diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index da583a2ba0c063a55dc149a26b2c6c9d771e1a2a..76e8ccc62a2d34acf333515043d20afc456b1924 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -371,6 +371,7 @@ py_test( "ar_model_test.py", ], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":ar_model", ":estimators", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index 7f85a04158b10545df9c0b5fa4506f955f39cf4a..ff140efd48104e386826eab7abbc94bec220f9df 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -89,8 +89,6 @@ class ARModel(model.TimeSeriesModel): self.hidden_layer_sizes = hidden_layer_sizes self.window_size = self.input_window_size + self.output_window_size self.loss = loss - self.stats_means = None - self.stats_sigmas = None super(ARModel, self).__init__( num_features=num_features) assert num_time_buckets > 0 @@ -106,32 +104,6 @@ class ARModel(model.TimeSeriesModel): assert len(self._periods) or self.input_window_size assert output_window_size > 0 - def scale_data(self, data): - """Scale data according to stats.""" - if self._input_statistics is not None: - return (data - self.stats_means) / self.stats_sigmas - else: - return data - - def scale_back_data(self, data): - if self._input_statistics is not None: - return (data * self.stats_sigmas) + self.stats_means - else: - return data - - def scale_back_variance(self, var): - if self._input_statistics is not None: - return var * self.stats_sigmas * self.stats_sigmas - else: - return var - - def initialize_graph(self, input_statistics=None): - super(ARModel, self).initialize_graph(input_statistics=input_statistics) - if self._input_statistics: - self.stats_means, variances = ( - self._input_statistics.overall_feature_moments) - self.stats_sigmas = math_ops.sqrt(variances) - def get_start_state(self): # State which matches the format we'll return later. Typically this will not # be used by the model directly, but the shapes and dtypes should match so @@ -388,8 +360,8 @@ class ARModel(model.TimeSeriesModel): predicted_covariance = array_ops.ones_like(predicted_mean) # Transform and scale the mean and covariance appropriately. - predicted_mean = self.scale_back_data(predicted_mean) - predicted_covariance = self.scale_back_variance(predicted_covariance) + predicted_mean = self._scale_back_data(predicted_mean) + predicted_covariance = self._scale_back_variance(predicted_covariance) return {"mean": predicted_mean, "covariance": predicted_covariance} @@ -418,7 +390,7 @@ class ARModel(model.TimeSeriesModel): times_feature=TrainEvalFeatures.TIMES, window_size=self.window_size, times_shape=times.get_shape())) - values = self.scale_data(values) + values = self._scale_data(values) if self.input_window_size > 0: input_values = values[:, :self.input_window_size, :] else: @@ -435,14 +407,14 @@ class ARModel(model.TimeSeriesModel): # (observed - predicted) ** 2. # Note that this affects only evaluation; the training loss is unaffected. loss = self.loss_op( - self.scale_back_data(targets), - {"mean": self.scale_back_data(prediction_ops["mean"])}) + self._scale_back_data(targets), + {"mean": self._scale_back_data(prediction_ops["mean"])}) else: loss = self.loss_op(targets, prediction_ops) # Scale back the prediction. - prediction = self.scale_back_data(prediction) - covariance = self.scale_back_variance(covariance) + prediction = self._scale_back_data(prediction) + covariance = self._scale_back_variance(covariance) return model.ModelOutputs( loss=loss, @@ -565,7 +537,7 @@ class ARModel(model.TimeSeriesModel): new_state_times.set_shape((None, self.input_window_size)) new_state_values = array_ops.concat( [previous_state_values, - self.scale_data(values)], axis=1)[:, -self.input_window_size:, :] + self._scale_data(values)], axis=1)[:, -self.input_window_size:, :] new_state_values.set_shape((None, self.input_window_size, self.num_features)) else: diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 3308f620d9624b38479ca9010ab969e75483b17e..3738dfa154d4f39b9562446972443ed88f3fbe8b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -20,8 +20,8 @@ from __future__ import print_function from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys -from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib +from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import state_management from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model from tensorflow.contrib.timeseries.python.timeseries.state_space_models import structural_ensemble @@ -59,9 +59,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator): if optimizer is None: optimizer = train.AdamOptimizer(0.02) self._model = model - model_fn = ts_head_lib.time_series_regression_head( + ts_regression_head = ts_head_lib.time_series_regression_head( model, state_manager, optimizer, - input_statistics_generator=input_statistics_generator).create_estimator_spec + input_statistics_generator=input_statistics_generator) + model_fn = ts_regression_head.create_estimator_spec super(TimeSeriesRegressor, self).__init__( model_fn=model_fn, model_dir=model_dir, diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index a8e22566cda086d414231fe81632252ab530ca50..5896fc2a206bc747688b5b012e0f87465592dd8a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -1,3 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Timeseries head.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -22,31 +37,34 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest -def time_series_regression_head( - model, state_manager, optimizer, input_statistics_generator=None): +def time_series_regression_head(model, + state_manager, + optimizer, + input_statistics_generator=None): """Creates a `_Head` for time series regression. Args: - weight_column: A string or a `_NumericColumn` created by - `tf.feature_column.numeric_column` defining feature column representing - weights. It is used to down weight or boost examples during training. It - will be multiplied by the loss of the example. - label_dimension: Number of regression labels per example. This is the size - of the last dimension of the labels `Tensor` (typically, this has shape - `[batch_size, label_dimension]`). + model: A model for time series regression. + state_manager: A state manager. + optimizer: An optimizer. + input_statistics_generator: A input statistics generator. Returns: An instance of `_Head` for time series regression. """ - return _TimeSeriesRegressionHead( - model, state_manager, optimizer, input_statistics_generator) + return _TimeSeriesRegressionHead(model, state_manager, optimizer, + input_statistics_generator) class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access """See `time_series_regression_head`.""" - def __init__(self, model, state_manager, optimizer, - input_statistics_generator=None, name=None): + def __init__(self, + model, + state_manager, + optimizer, + input_statistics_generator=None, + name=None): self.model = model self.state_manager = state_manager self.optimizer = optimizer @@ -56,31 +74,33 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _train_ops(self, features): """Add training ops to the graph.""" with variable_scope.variable_scope("model"): - model_outputs = self.state_manager.define_loss(self.model, features, - estimator_lib.ModeKeys.TRAIN) + model_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.TRAIN) + train_op = optimizers.optimize_loss( - model_outputs.loss, - global_step=variables.get_global_step(), - optimizer=self.optimizer, - # Learning rate is set in the Optimizer object - learning_rate=None) + model_outputs.loss, + global_step=variables.get_global_step(), + optimizer=self.optimizer, + # Learning rate is set in the Optimizer object + learning_rate=None) return estimator_lib.EstimatorSpec( - loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.TRAIN, - train_op=train_op) + loss=model_outputs.loss, + mode=estimator_lib.ModeKeys.TRAIN, + train_op=train_op) - # TODO: suffix summary and metrics keys by `"/" + name` + # TODO(terrytangyuan): suffix summary and metrics keys by `"/" + name` @property def name(self): return self._name - # TOOD: unused for now. Need to decouple `state_manager.define_loss` - # to satisfy the extendable return signature of `_Head.create_loss`. + # TODO(terrytangyuan): unused for now. Need to decouple + # `state_manager.define_loss` to satisfy the extendable return signature of + # `_Head.create_loss`. def create_loss(self, features, mode, logits, labels): """See `_Head`.""" return None - # TODO: check label dimension + # TODO(terrytangyuan): check label dimension @property def logits_dimension(self): return None @@ -88,58 +108,59 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _evaluate_ops(self, features): """Add ops for evaluation (aka filtering) to the graph.""" with variable_scope.variable_scope("model"): - model_outputs = self.state_manager.define_loss(self.model, features, - estimator_lib.ModeKeys.EVAL) + model_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.EVAL) metrics = {} # Just output in-sample predictions for the last chunk seen for prediction_key, prediction_value in model_outputs.predictions.items(): metrics[prediction_key] = _identity_metric_single(prediction_key, prediction_value) metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single( - feature_keys.FilteringResults.TIMES, model_outputs.prediction_times) + feature_keys.FilteringResults.TIMES, model_outputs.prediction_times) metrics[feature_keys.FilteringResults.STATE_TUPLE] = ( - _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE, - model_outputs.end_state)) + _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE, + model_outputs.end_state)) return estimator_lib.EstimatorSpec( - loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.EVAL, - eval_metric_ops=metrics, - predictions={}) + loss=model_outputs.loss, + mode=estimator_lib.ModeKeys.EVAL, + eval_metric_ops=metrics, + predictions={}) def _predict_ops(self, features): """Add ops for prediction to the graph.""" with variable_scope.variable_scope("model"): prediction = self.model.predict(features=features) prediction[feature_keys.PredictionResults.TIMES] = features[ - feature_keys.PredictionFeatures.TIMES] + feature_keys.PredictionFeatures.TIMES] return estimator_lib.EstimatorSpec( - predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT) + predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT) def _serving_ops(self, features): """Add ops for serving to the graph.""" with variable_scope.variable_scope("model"): prediction_outputs = self.model.predict(features=features) with variable_scope.variable_scope("model", reuse=True): - filtering_outputs = self.state_manager.define_loss(self.model, features, - estimator_lib.ModeKeys.EVAL) + filtering_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.EVAL) + return estimator_lib.EstimatorSpec( - mode=estimator_lib.ModeKeys.PREDICT, - export_outputs={ - feature_keys.SavedModelLabels.PREDICT: - export_lib.PredictOutput(prediction_outputs), - feature_keys.SavedModelLabels.FILTER: - export_lib.PredictOutput( - state_to_dictionary(filtering_outputs.end_state)) - }, - # Likely unused, but it is necessary to return `predictions` to satisfy - # the Estimator's error checking. - predictions={}) + mode=estimator_lib.ModeKeys.PREDICT, + export_outputs={ + feature_keys.SavedModelLabels.PREDICT: + export_lib.PredictOutput(prediction_outputs), + feature_keys.SavedModelLabels.FILTER: + export_lib.PredictOutput( + state_to_dictionary(filtering_outputs.end_state)) + }, + # Likely unused, but it is necessary to return `predictions` to satisfy + # the Estimator's error checking. + predictions={}) def _convert_feature_to_tensor(self, name, value): """Casts features to the correct dtype based on their name.""" if name in [ - feature_keys.TrainEvalFeatures.TIMES, - feature_keys.PredictionFeatures.TIMES + feature_keys.TrainEvalFeatures.TIMES, + feature_keys.PredictionFeatures.TIMES ]: return math_ops.cast(value, dtypes.int64) if name == feature_keys.TrainEvalFeatures.VALUES: @@ -164,39 +185,45 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc del features[key] numbered_state.sort(key=lambda number, *_: number) features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as( - structure=self.model.get_start_state(), - flat_sequence=[tensor for _, _, tensor in numbered_state]) + structure=self.model.get_start_state(), + flat_sequence=[tensor for _, _, tensor in numbered_state]) return features, True def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope("head"): if labels: - raise ValueError("The model received a `labels` dictionary, which is not" - " supported. Pass '{}' and '{}' as features.".format( - feature_keys.TrainEvalFeatures.TIMES, - feature_keys.TrainEvalFeatures.VALUES)) + raise ValueError( + "The model received a `labels` dictionary, which is " + "not supported. Pass '{}' and '{}' as " + "features.".format(feature_keys.TrainEvalFeatures.TIMES, + feature_keys.TrainEvalFeatures.VALUES)) del labels - features = {name: self._convert_feature_to_tensor(name=name, value=value) - for name, value in features.items()} + features = { + name: self._convert_feature_to_tensor(name=name, value=value) + for name, value in features.items() + } if self.input_statistics_generator is not None: input_statistics = self.input_statistics_generator.initialize_graph( - features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN)) + features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN)) else: input_statistics = None self.model.initialize_graph(input_statistics=input_statistics) - # _gather_state requires the model to have its graph initialized (so it has - # access to the structure of the model's state) + + # _gather_state requires the model to have its graph initialized (so it + # has access to the structure of the model's state) features, passed_flat_state = self._gather_state(features) - if (mode == estimator_lib.ModeKeys.TRAIN - or mode == estimator_lib.ModeKeys.EVAL): + if (mode == estimator_lib.ModeKeys.TRAIN or + mode == estimator_lib.ModeKeys.EVAL): _check_train_eval_features(features, self.model) elif mode == estimator_lib.ModeKeys.PREDICT: _check_predict_features(features) else: raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) + self.state_manager.initialize_graph( - model=self.model, input_statistics=input_statistics) + model=self.model, input_statistics=input_statistics) + if mode == estimator_lib.ModeKeys.TRAIN: return self._train_ops(features) elif mode == estimator_lib.ModeKeys.EVAL: @@ -210,8 +237,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc return self._serving_ops(features) -def _check_feature_shapes_compatible_with( - features, compatible_with_name, compatible_with_value, ignore=None): +def _check_feature_shapes_compatible_with(features, + compatible_with_name, + compatible_with_value, + ignore=None): """Checks all features are compatible with the given time-like feature.""" if ignore is None: ignore = set() @@ -223,77 +252,77 @@ def _check_feature_shapes_compatible_with( continue if feature_shape.ndims < 2: raise ValueError( - ("Features must have shape (batch dimension, window size, ...) " - "(got rank {} for feature '{}')").format( - feature_shape.ndims, name)) + ("Features must have shape (batch dimension, window size, ...) " + "(got rank {} for feature '{}')").format(feature_shape.ndims, name)) if not feature_shape[:2].is_compatible_with( - compatible_with_value.get_shape()): + compatible_with_value.get_shape()): raise ValueError( - ("Features must have shape (batch dimension, window size, ...) " - "where batch dimension and window size match the " - "'{times_feature}' feature (got shape {feature_shape} for " - "feature '{feature_name}' but shape {times_shape} for feature " - "'{times_feature}')").format( - times_feature=compatible_with_name, - feature_shape=feature_shape, - feature_name=name, - times_shape=compatible_with_value.get_shape())) + ("Features must have shape (batch dimension, window size, ...) " + "where batch dimension and window size match the " + "'{times_feature}' feature (got shape {feature_shape} for " + "feature '{feature_name}' but shape {times_shape} for feature " + "'{times_feature}')").format( + times_feature=compatible_with_name, + feature_shape=feature_shape, + feature_name=name, + times_shape=compatible_with_value.get_shape())) def _check_predict_features(features): """Raises errors if features are not suitable for prediction.""" if feature_keys.PredictionFeatures.TIMES not in features: raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.TIMES)) + feature_keys.PredictionFeatures.TIMES)) if feature_keys.PredictionFeatures.STATE_TUPLE not in features: raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.STATE_TUPLE)) + feature_keys.PredictionFeatures.STATE_TUPLE)) times_feature = features[feature_keys.PredictionFeatures.TIMES] if not times_feature.get_shape().is_compatible_with([None, None]): raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, - times_feature.get_shape())) + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.PredictionFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes - ])) + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes + ])) def _check_train_eval_features(features, model): """Raise errors if features are not suitable for training/evaluation.""" if feature_keys.TrainEvalFeatures.TIMES not in features: raise ValueError("Expected a '{}' feature for training/evaluation.".format( - feature_keys.TrainEvalFeatures.TIMES)) + feature_keys.TrainEvalFeatures.TIMES)) if feature_keys.TrainEvalFeatures.VALUES not in features: raise ValueError("Expected a '{}' feature for training/evaluation.".format( - feature_keys.TrainEvalFeatures.VALUES)) + feature_keys.TrainEvalFeatures.VALUES)) times_feature = features[feature_keys.TrainEvalFeatures.TIMES] if not times_feature.get_shape().is_compatible_with([None, None]): raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES, - times_feature.get_shape())) + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES, + times_feature.get_shape())) values_feature = features[feature_keys.TrainEvalFeatures.VALUES] if not values_feature.get_shape().is_compatible_with( - [None, None, model.num_features]): + [None, None, model.num_features]): raise ValueError( - ("Expected shape (batch dimension, window size, {num_features}) " - "for feature '{feature_name}', since the model was configured " - "with num_features={num_features} (got shape {got_shape})").format( - num_features=model.num_features, - feature_name=feature_keys.TrainEvalFeatures.VALUES, - got_shape=times_feature.get_shape())) + ("Expected shape (batch dimension, window size, {num_features}) " + "for feature '{feature_name}', since the model was configured " + "with num_features={num_features} (got shape {got_shape})").format( + num_features=model.num_features, + feature_name=feature_keys.TrainEvalFeatures.VALUES, + got_shape=times_feature.get_shape())) _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.TrainEvalFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.State.STATE_TUPLE # Model-dependent shapes - ])) + features=features, + compatible_with_name=feature_keys.TrainEvalFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + feature_keys.State.STATE_TUPLE # Model-dependent shapes + ])) + def _identity_metric_single(name, input_tensor): """A metric which takes on its last updated value. @@ -311,12 +340,12 @@ def _identity_metric_single(name, input_tensor): A tuple of (value, update_op). """ metric_variable = variable_scope.variable( - name="{}_identity_metric".format(name), - initial_value=array_ops.zeros([], dtype=input_tensor.dtype), - collections=[ops.GraphKeys.LOCAL_VARIABLES], - validate_shape=False) - update_op = state_ops.assign(metric_variable, input_tensor, - validate_shape=False) + name="{}_identity_metric".format(name), + initial_value=array_ops.zeros([], dtype=input_tensor.dtype), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + validate_shape=False) + update_op = state_ops.assign( + metric_variable, input_tensor, validate_shape=False) # This shape will be correct once the first update runs (but may be # incomplete, so is not helpful for initializing the variable). metric_variable.set_shape(input_tensor.get_shape()) @@ -329,13 +358,13 @@ def _identity_metric_nested(name, input_tensors): value_tensors = [] for tensor_number, tensor in enumerate(nest.flatten(input_tensors)): value_tensor, update_op = _identity_metric_single( - name="{}_{}".format(name, tensor_number), - input_tensor=tensor) + name="{}_{}".format(name, tensor_number), input_tensor=tensor) update_ops.append(update_op) value_tensors.append(value_tensor) return (nest.pack_sequence_as(input_tensors, value_tensors), control_flow_ops.group(*update_ops)) + def state_to_dictionary(state_tuple): """Flatten model state into a dictionary with string keys.""" flattened = {} @@ -344,4 +373,3 @@ def state_to_dictionary(state_tuple): state_number) flattened[prefixed_state_name] = state_value return flattened - diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 7ebcebfe1b156a6c0cc86fa1ded55e4a645d291f..3415061cfd87358cccaf36dcb301fb36986bbde6 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.timeseries.python.timeseries import feature_keys -from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib +from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import state_management from tensorflow.python.estimator import estimator_lib diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index c70da3e082245e76ab3225676c2d37c4ea95292d..23452a81c397da3516016d72b7bc9b80f7d6447f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -936,8 +936,7 @@ class InputStatisticsFromMiniBatch(object): start_time = variable_scope.get_variable( name="start_time", dtype=dtypes.int64, - initializer=init_ops.zeros_initializer(), - shape=[], + initializer=dtypes.int64.max, trainable=False) total_observation_count = variable_scope.get_variable( name="total_observation_count", diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index f2ef8d22114be50a10d3b106be5e144cc70b4bfc..b32b5c5494ae14187954b900119678a5b53a3602 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -80,6 +80,8 @@ class TimeSeriesModel(object): self.dtype = dtype self._input_statistics = None self._graph_initialized = False + self._stats_means = None + self._stats_sigmas = None # TODO(allenl): Move more of the generic machinery for generating and # predicting into TimeSeriesModel, and possibly share it between generate() @@ -120,6 +122,38 @@ class TimeSeriesModel(object): """ self._graph_initialized = True self._input_statistics = input_statistics + if self._input_statistics: + self._stats_means, variances = ( + self._input_statistics.overall_feature_moments) + self._stats_sigmas = math_ops.sqrt(variances) + + def _scale_data(self, data): + """Scale data according to stats (input scale -> model scale).""" + if self._input_statistics is not None: + return (data - self._stats_means) / self._stats_sigmas + else: + return data + + def _scale_variance(self, variance): + """Scale variances according to stats (input scale -> model scale).""" + if self._input_statistics is not None: + return variance / self._input_statistics.overall_feature_moments.variance + else: + return variance + + def _scale_back_data(self, data): + """Scale back data according to stats (model scale -> input scale).""" + if self._input_statistics is not None: + return (data * self._stats_sigmas) + self._stats_means + else: + return data + + def _scale_back_variance(self, variance): + """Scale back variances according to stats (model scale -> input scale).""" + if self._input_statistics is not None: + return variance * self._input_statistics.overall_feature_moments.variance + else: + return variance def _check_graph_initialized(self): if not self._graph_initialized: @@ -304,6 +338,7 @@ class SequentialTimeSeriesModel(TimeSeriesModel): train_output_names, predict_output_names, num_features, + normalize_features=False, dtype=dtypes.float32, exogenous_feature_columns=None, exogenous_update_condition=None, @@ -316,6 +351,12 @@ class SequentialTimeSeriesModel(TimeSeriesModel): predict_output_names: A list of products/predictions returned from _prediction_step. num_features: Number of features for the time series + normalize_features: Boolean. If True, `values` are passed normalized to + the model (via self._scale_data). Scaling is done for the whole window + as a batch, which is slightly more efficient than scaling inside the + window loop. The model must then define _scale_back_predictions, which + may use _scale_back_data or _scale_back_variance to return predictions + to the input scale. dtype: The floating point datatype to use. exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn objects. See `TimeSeriesModel`. @@ -344,9 +385,25 @@ class SequentialTimeSeriesModel(TimeSeriesModel): self._exogenous_update_condition = exogenous_update_condition self._train_output_names = train_output_names self._predict_output_names = predict_output_names + self._normalize_features = normalize_features self._static_unrolling_window_size_threshold = ( static_unrolling_window_size_threshold) + def _scale_back_predictions(self, predictions): + """Return a window of predictions to input scale. + + Args: + predictions: A dictionary mapping from prediction names to Tensors. + Returns: + A dictionary with values corrected for input normalization (e.g. with + self._scale_back_mean and possibly self._scale_back_variance). May be a + mutated version of the argument. + """ + raise NotImplementedError( + "SequentialTimeSeriesModel normalized input data" + " (normalize_features=True), but no method was provided to transform " + "the predictions back to the input scale.") + @abc.abstractmethod def _filtering_step(self, current_times, current_values, state, predictions): """Compute a single-step loss for a batch of data. @@ -524,6 +581,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): self._check_graph_initialized() times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64) values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype) + if self._normalize_features: + values = self._scale_data(values) exogenous_regressors = self._process_exogenous_features( times=times, features={key: value for key, value in features.items() @@ -556,6 +615,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): # Since we have window-level additions to the loss, its per-step value is # misleading, so we avoid returning it. del outputs["loss"] + if self._normalize_features: + outputs = self._scale_back_predictions(outputs) return per_observation_loss, state, outputs def predict(self, features): @@ -583,6 +644,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): times=predict_times, state=start_state, state_update_fn=_call_prediction_step, outputs=self._predict_output_names) + if self._normalize_features: + predictions = self._scale_back_predictions(predictions) return predictions class _FakeTensorArray(object): diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py index b9d3f55c39d32bb9f14829842fcad85571de6855..56167c4f012b42a4e7d56c5e6eac7862d50bd59b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py @@ -57,7 +57,9 @@ class AdderStateSpaceModel(state_space_model.StateSpaceModel): # TODO(allenl): Better support for multivariate series here. initial_value = array_ops.stack([ math_ops.reduce_mean( - self._input_statistics.series_start_moments.mean), 0. + self._scale_data( + self._input_statistics.series_start_moments.mean)), + 0. ]) return initial_value + variable_scope.get_variable( name="prior_state_mean", diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py index 6a9660b400d08a0397103676344ea1969fbc1f7a..6257002647ed53bbde3ace11a6b45e4e2cdeb57d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py @@ -232,6 +232,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): + filtering_postprocessor_names), predict_output_names=["mean", "covariance"], num_features=configuration.num_features, + normalize_features=True, dtype=configuration.dtype, exogenous_feature_columns=configuration.exogenous_feature_columns, exogenous_update_condition=configuration.exogenous_update_condition, @@ -309,15 +310,10 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): _, _, priors_from_time = state times = ops.convert_to_tensor(times) priors_from_time = ops.convert_to_tensor(priors_from_time) - with ops.control_dependencies([ - control_flow_ops.Assert( - math_ops.reduce_all(priors_from_time <= times[:, 0]), - [priors_from_time, times[:, 0]], - summarize=100) - ]): - times = array_ops.identity(times) intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1]) - starting_gaps = times[:, 0] - priors_from_time + # Ignore negative starting gaps, since there will be transient start times + # as inputs statistics are computed. + starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0) # Pre-define transition matrices raised to powers (and their sums) for every # gap in this window. This avoids duplicate computation (for example many # steps will use the transition matrix raised to the first power) and @@ -369,20 +365,15 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): Imputed model state corresponding to the `state` argument. """ estimated_state, estimated_state_var, previous_times = state - catchup_times = current_times - previous_times - non_negative_assertion = control_flow_ops.Assert( - math_ops.reduce_all(catchup_times >= 0), [ - "Negative imputation interval", catchup_times, current_times, - previous_times - ], - summarize=100) - with ops.control_dependencies([non_negative_assertion]): - transition_matrices, transition_noise_sums = ( # pylint: disable=unbalanced-tuple-unpacking - self._cached_transition_powers_and_sums(catchup_times)) - estimated_state = self._kalman_filter.predict_state_mean( - estimated_state, transition_matrices) - estimated_state_var = self._kalman_filter.predict_state_var( - estimated_state_var, transition_matrices, transition_noise_sums) + # Ignore negative imputation intervals due to transient start time + # estimates. + catchup_times = math_ops.maximum(current_times - previous_times, 0) + transition_matrices, transition_noise_sums = ( # pylint: disable=unbalanced-tuple-unpacking + self._cached_transition_powers_and_sums(catchup_times)) + estimated_state = self._kalman_filter.predict_state_mean( + estimated_state, transition_matrices) + estimated_state_var = self._kalman_filter.predict_state_var( + estimated_state_var, transition_matrices, transition_noise_sums) return (estimated_state, estimated_state_var, previous_times + catchup_times) @@ -437,6 +428,13 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): outputs=predictions) return (filtered_state, predictions) + def _scale_back_predictions(self, predictions): + """Return a window of predictions to input scale.""" + predictions["mean"] = self._scale_back_data(predictions["mean"]) + predictions["covariance"] = self._scale_back_variance( + predictions["covariance"]) + return predictions + def _prediction_step(self, current_times, state): """Make a prediction based on `state`. @@ -458,7 +456,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): """ estimated_state, estimated_state_var, previous_times = state advanced_to_current_assert = control_flow_ops.Assert( - math_ops.reduce_all(math_ops.equal(current_times, previous_times)), + math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)), ["Attempted to predict without imputation"]) with ops.control_dependencies([advanced_to_current_assert]): observation_model = self.get_broadcasted_observation_model(current_times) @@ -475,6 +473,9 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): (self.num_features,))) predicted_obs_var.set_shape(current_times.get_shape().concatenate( (self.num_features, self.num_features))) + # Not scaled back to input-scale, since this also feeds into the + # loss. Instead, predictions are scaled back before being returned to the + # user in _scale_back_predictions. predictions = { "mean": predicted_obs, "covariance": predicted_obs_var} @@ -722,7 +723,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): # Make sure initial latent value uncertainty is at least on the same # scale as noise in the data. covariance_multiplier = math_ops.reduce_max( - self._input_statistics.series_start_moments.variance) + self._scale_variance( + self._input_statistics.series_start_moments.variance)) return base_covariance * gen_math_ops.maximum( covariance_multiplier, 1.0) else: @@ -920,7 +922,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): self.get_noise_transform(), dtype=self.dtype) state_noise_dimension = state_noise_transform.get_shape()[1].value if self._input_statistics is not None: - feature_variance = self._input_statistics.series_start_moments.variance + feature_variance = self._scale_variance( + self._input_statistics.series_start_moments.variance) initial_transition_noise_scale = math_ops.log( gen_math_ops.maximum( math_ops.reduce_mean(feature_variance) / math_ops.cast( @@ -945,7 +948,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): if self._input_statistics is not None: # Get variance across the first few values in each batch for each # feature, for an initial observation noise (over-)estimate. - feature_variance = self._input_statistics.series_start_moments.variance + feature_variance = self._scale_variance( + self._input_statistics.series_start_moments.variance) else: feature_variance = None if feature_variance is not None: diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py index 7c8f81ec5165b8ba7e8a1089953e5755b5a90915..ca57715e2b2e6bbadd276d641703c0a3b842652e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py @@ -605,6 +605,7 @@ class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel): super(TimeDependentStateSpaceModel, self).__init__( configuration=state_space_model.StateSpaceModelConfiguration( use_observation_noise=False, + transition_covariance_initial_log_scale_bias=5., static_unrolling_window_size_threshold= static_unrolling_window_size_threshold)) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py index 110ba9738f8c28109282b927fd07ade071bb3e4a..1afc58cfb240c52a9f001da787addfb7fbb46789 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py @@ -182,7 +182,8 @@ class VARMA(state_space_model.StateSpaceModel): # modeled as transition noise in VARMA, we set its initial value based on a # slight over-estimate empirical observation noise. if self._input_statistics is not None: - feature_variance = self._input_statistics.series_start_moments.variance + feature_variance = self._scale_variance( + self._input_statistics.series_start_moments.variance) initial_transition_noise_scale = math_ops.log( math_ops.maximum( math_ops.reduce_mean(feature_variance), minimum_initial_variance)) diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index e753fe7a5140028f238c2ff3754b1d7335ae8eb2..970fc97605057bd65fc5c0796f6a6a5f0a27e458 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -35,6 +35,7 @@ py_library( srcs = [ "python/tpu/tpu_config.py", "python/tpu/tpu_estimator.py", + "python/tpu/util.py", ], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index a40e2a7898a304c21a60929b30719f3132aec0f0..b40dac471708793d5a033279e2d2f4b4a0dac480 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -22,6 +22,11 @@ namespace tensorflow { using shape_inference::InferenceContext; using shape_inference::ShapeHandle; +REGISTER_OP("TPUReplicateMetadata") + .Attr("num_replicas: int >= 0") + .Attr("global_tpu_id: list(int) = []") + .SetShapeFn(shape_inference::UnknownShape); + REGISTER_OP("TPUReplicatedInput") .Input("inputs: N * T") .Output("output: T") diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 8d3344fac36be24a692f141eee140312d988a932..33e47f674d798f622fb08121dabb67d7f45af15b 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -21,9 +21,11 @@ from __future__ import print_function import platform +from tensorflow.python.framework import ops if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tpu.ops import gen_tpu_ops from tensorflow.contrib.tpu.ops.gen_tpu_ops import * from tensorflow.contrib.util import loader @@ -32,6 +34,12 @@ if platform.system() != "Windows": _tpu_ops = loader.load_op_library( resource_loader.get_path_to_datafile("_tpu_ops.so")) + + @ops.RegisterGradient("CrossReplicaSum") + def _cross_replica_sum_grad(op, grad): + del op # Unused + # The gradient of a cross replica sum is also a cross-replica sum. + return gen_tpu_ops.cross_replica_sum(grad) else: # We have already built the appropriate libraries into the binary via CMake # if we have built contrib, so we don't need this diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index f6800e3e246dc5f6242a7bf127f6397fedf92b9f..fa5760953dfa9353a04f9af49b320f57a73cc275 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -105,9 +105,8 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): """A ControlFlowContext for nodes inside a TPU computation. The primary role of TPUReplicateContext is to mark operators inside a - tpu.replicate() computation with attributes: - * _tpu_replicate=XYZ, where XYZ is a unique name, and - * _tpu_num_replicas=k, where k is the number of replicas. + tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ + is a unique name. We use a ControlFlowContext to perform the annotation since it integrates with Tensorflow constructs like ResourceVariables. For example, @@ -116,11 +115,9 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): to build the variable's definition outside the replicated computation. """ - def __init__(self, name, num_replicas, global_tpu_id=None): + def __init__(self, name): control_flow_ops.ControlFlowContext.__init__(self) self._name = name - self._num_replicas = num_replicas - self._global_tpu_id = [] if global_tpu_id is None else global_tpu_id def AddOp(self, op): self._AddOpInternal(op) @@ -135,8 +132,6 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): if "_tpu_replicate" in op.node_def.attr: raise ValueError("TPU computations cannot be nested") op.node_def.attr["_tpu_replicate"].s = self._name - op.node_def.attr["_tpu_num_replicas"].i = self._num_replicas - op.node_def.attr["_tpu_global_id"].list.i.extend(self._global_tpu_id) op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) @@ -243,14 +238,15 @@ def replicate(computation, computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - context = TPUReplicateContext( - name=graph.unique_name("cluster"), - num_replicas=num_replicas, - global_tpu_id=global_tpu_id) + context = TPUReplicateContext(name=graph.unique_name("cluster")) try: context.Enter() - with tpu_function.tpu_shard_context(num_replicas): + metadata = tpu_ops.tpu_replicate_metadata( + num_replicas=num_replicas, global_tpu_id=global_tpu_id) + + with tpu_function.tpu_shard_context( + num_replicas), ops.control_dependencies([metadata]): # The EncapsulateTPUComputations rewrite needs to identify the # replicated arguments inside each computation. Adds identity operators diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 02135bfe40e72860d474f441b6cc57430d4e0fca..79fd8b839b475350d4bedaacf2d48596020a0c4e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -21,12 +21,16 @@ from __future__ import print_function import collections +from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.python.estimator import run_config as run_config_lib class TPUConfig( collections.namedtuple('TPUConfig', [ - 'iterations_per_loop', 'num_shards', 'per_host_input_for_training' + 'iterations_per_loop', + 'num_shards', + 'per_host_input_for_training', + 'tpu_job_name', ])): """TPU related configuration required by `TPUEstimator`. @@ -36,28 +40,57 @@ class TPUConfig( global step is increased `iterations_per_loop` times in one `Session.run`. It is recommended to be set as number of global steps for next checkpoint. num_shards: The number of TPU shards in the system. - per_host_input_for_training: If `True`, `input_fn` is invoked per host - rather than per shard. Note: This behavior is going to be default as - `True` soon, so this flag will be removed after that. Also note that this - only works for single-host TPU training now. + per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host + rather than Per-Core. With Per-Host input pipeline deployment, `input_fn` + is invoked once on each host. To be precise, with a global batch size + `train_batch_size` in `TPUEstimator` constructor, the batch size for each + shard is `train_batch_size` // #hosts. With Per-Core input pipeline + deployment, the shard batch size is `train_batch_size` // #cores. Note + that this only works for single-host TPU training now (tracked in + b/67051042). For multi-host, please use Per-Core, i.e., `False` for + `per_host_input_for_training`. + tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred + within TPUEstimator, however when using ClusterSpec propagation in more + esoteric cluster configurations, you may need to specify the job name as a + string. """ def __new__(cls, iterations_per_loop=2, num_shards=2, - per_host_input_for_training=False): + per_host_input_for_training=True, + tpu_job_name=None): + + # Check iterations_per_loop. + util_lib.check_positive_integer(iterations_per_loop, + 'TPUConfig iterations_per_loop') + + # Check num_shards. + util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') return super(TPUConfig, cls).__new__( cls, iterations_per_loop=iterations_per_loop, num_shards=num_shards, - per_host_input_for_training=per_host_input_for_training) + per_host_input_for_training=per_host_input_for_training, + tpu_job_name=tpu_job_name) class RunConfig(run_config_lib.RunConfig): """RunConfig with TPU support.""" def __init__(self, tpu_config=None, evaluation_master='', master='', - **kwargs): + tf_random_seed=None, **kwargs): + """Constructs a RunConfig. + + Args: + tpu_config: the TPUConfig that specifies TPU-specific configuration. + evaluation_master: a string. The address of the master to use for eval. + master: a string. The address of the master to use for training. + tf_random_seed: an int. Sets the TensorFlow random seed. Defaults to None, + which initializes it randomly based on the environment. + """ + # We change the default random seed to None because that's a better default. + kwargs['tf_random_seed'] = tf_random_seed super(RunConfig, self).__init__(**kwargs) self._tpu_config = tpu_config or TPUConfig() self._evaluation_master = evaluation_master diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index b5001d596b8cb34b0cfd32df0864e466ab7d86b6..04e0719a1be90cb3b094109d737b4f0db5fa0ce2 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,6 +31,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.core.protobuf import config_pb2 @@ -121,12 +122,55 @@ def _increase_eval_step_op(iterations_per_loop): use_locking=True) +_DEFAULT_JOB_NAME = 'tpu_worker' +_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' +_LOCAL_MASTERS = ('', 'local') + + def _tpu_job(run_config, mode): + """Returns the job name to use to place TPU computations on. + + Args: + run_config: The tpu_config.RunConfig used for this custom estimator. + mode: A model_fn_lib.ModeKeys value. + + Returns: + A string containing the job name, or None if no job should be specified. + + Raises: + ValueError: If the user needs to specify a tpu_job_name, because we are + unable to infer the job name automatically, or if the user-specified job + names are inappropriate. + """ + # If the user specifies the tpu_job_name, use that. + if run_config.tpu_config.tpu_job_name: + return run_config.tpu_config.tpu_job_name + # The tpu job is determined by the run_config. Right now, this method is # required as tpu_config is not part of the RunConfig. master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) - return None if master in ['', 'local'] else 'tpu_worker' + if master in _LOCAL_MASTERS: + return None + + if (not run_config.session_config or + not run_config.session_config.cluster_def.job): + return _DEFAULT_JOB_NAME + cluster_def = run_config.session_config.cluster_def + job_names = set([job.name for job in cluster_def.job]) + if _DEFAULT_JOB_NAME in job_names: + # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. + raise ValueError('Currently, tpu_worker is not an allowed job name.') + if len(job_names) == 1: + return cluster_def.job[0].name + if len(job_names) == 2: + if _DEFAULT_COORDINATOR_JOB_NAME in job_names: + job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) + return job_names.pop() + # TODO(b/67716447): Include more sophisticated heuristics. + raise ValueError( + 'Could not infer TPU job name. Please specify a tpu_job_name as part of ' + 'your TPUConfig.') def _is_running_on_cpu(use_tpu, mode, eval_batch_size): @@ -268,17 +312,25 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController): def _input_thread_fn_for_loading(self, session, enqueue_ops): count = 0 - while True: - signal = self._signal_queue.get() - if signal == _SIGNAL.STOP: - logging.info('Stop Infeed input thread.') - return - - iterations = signal - for i in range(iterations): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(enqueue_ops) - count += 1 + try: + while True: + signal = self._signal_queue.get() + if signal == _SIGNAL.STOP: + logging.info('Stop Infeed input thread.') + return + + iterations = signal + for i in range(iterations): + logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + session.run(enqueue_ops) + count += 1 + except Exception: # pylint: disable=broad-except + logging.error( + 'Failed running infeed, closing session.\n' + 'You may see an exception from your main session after this.', + exc_info=1 + ) + session.close() def join(self): logging.info('Waiting for Infeed Thread to exit.') @@ -1318,6 +1370,12 @@ class TPUEstimator(estimator_lib.Estimator): 'For TPU training, one of `steps` or `max_steps` must be set. ' 'Cannot be both `None`.') + # Estimator.train has explicit positiveness check. + if steps is not None: + util_lib.check_positive_integer(steps, 'Train steps') + if max_steps is not None: + util_lib.check_positive_integer(max_steps, 'Train max_steps') + return [_TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)] @@ -1328,8 +1386,8 @@ class TPUEstimator(estimator_lib.Estimator): if steps is None: raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') - if steps <= 0: - raise ValueError('Must specify steps > 0, given: {}'.format(steps)) + + util_lib.check_positive_integer(steps, 'Eval steps') hooks = [] hooks.append(evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access @@ -1609,3 +1667,5 @@ def _validate_tpu_training_graph(): if not cross_replica_sum_ops: raise ValueError( 'CrossShardOptimizer must be used for model training on TPUs.') + + diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py index d545a94ca6a2fdb3a9df2748b59300fd141dc55d..f8ba7d45e20b2f48e1409427665878df40a6db02 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -177,6 +177,10 @@ class ShardingPolicy(object): raise ValueError("shape %s does not contain shard_dimension %d" % (shape.as_list(), self._shard_dimension)) dims = shape.as_list() + if dims[self._shard_dimension] is None: + raise ValueError("shape %s must have a fixed size for dimension %d " + "that is known at graph construction time." % + (shape.as_list(), self._shard_dimension)) if (dims[self._shard_dimension] % self._number_of_shards) != 0: raise ValueError("shape %s cannot be sharded %d ways along dimension %d" % (shape.as_list(), self._number_of_shards, diff --git a/tensorflow/contrib/tpu/python/tpu/util.py b/tensorflow/contrib/tpu/python/tpu/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ea307d8900cf1b6d1e6e808d0b9ede26f86490 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/util.py @@ -0,0 +1,31 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== + +"""Utilities for the functionalities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + + +def check_positive_integer(value, name): + """Checks whether `value` is a positive integer.""" + if not isinstance(value, six.integer_types): + raise TypeError('{} must be int, got {}'.format(name, type(value))) + + if value <= 0: + raise ValueError('{} must be positive, got {}'.format(name, value)) diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 8e3d869a51c440e00059851f05f6ed2fe5558416..80a5debe996a330d64e62ce430d33d4111ee8767 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -263,6 +263,7 @@ py_test( srcs = ["python/training/training_test.py"], shard_count = 3, srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":training_py", "//tensorflow/contrib/framework:framework_py", diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index f6237872cce9be57809b12f8f5067646f328cb96..2a0ef0e6b3750b4f0464f1f4390819e1fc2c7872 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -527,6 +528,50 @@ class PaddingTest(test.TestCase): self.assertTrue( math_ops.reduce_all(math_ops.equal(val, padded_seq[key])).eval()) + def testPaddingOnlySparse(self): + ind1 = np.array([[0], [2]]) + val1 = np.array([3, 4]) + shape1 = np.array([4]) + + ind2 = np.array([[1], [2]]) + val2 = np.array([9, 12]) + shape2 = np.array([5]) + + with ops.Graph().as_default() as g, self.test_session(graph=g): + sp_tensor1 = sparse_tensor.SparseTensor( + indices=array_ops.constant(ind1, dtypes.int64), + values=array_ops.constant(val1, dtypes.int64), + dense_shape=array_ops.constant(shape1, dtypes.int64)) + sp_tensor2 = sparse_tensor.SparseTensor( + indices=array_ops.constant(ind2, dtypes.int64), + values=array_ops.constant(val2, dtypes.int64), + dense_shape=array_ops.constant(shape2, dtypes.int64)) + + sp_tensor1_expected = sparse_tensor.SparseTensor( + indices=sp_tensor1.indices, + values=sp_tensor1.values, + dense_shape=[8]) + sp_tensor2_expected = sparse_tensor.SparseTensor( + indices=sp_tensor2.indices, + values=sp_tensor2.values, + dense_shape=[8]) + + sequences = { + "key_1": sp_tensor1, + "key_2": sp_tensor2, + } + _, padded_seq = sqss._padding(sequences, 4) + + expected_padded_seq = { + "key_1": sp_tensor1_expected, + "key_2": sp_tensor2_expected, + } + + for key, val in expected_padded_seq.items(): + self.assertAllEqual( + sparse_ops.sparse_tensor_to_dense(val).eval(), + sparse_ops.sparse_tensor_to_dense(padded_seq[key]).eval()) + class SparseTensorReConstructionTest(test.TestCase): diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 119fa3824bd77724471768980783e105d5595c4b..c95a73ce4492caa81cc6b902a782717de06c1b63 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -138,7 +138,7 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values, def parse_values(values, type_map): - """Parses hyperparameter values from a string into a python map.. + """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. For each pair, the value of the hyperparameter named `name` is set to diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index 778cf985cada74458ff8022b3af56f1047bf46b2..72231948856b38edd3d022a99a62e6d4c8c5649e 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -1596,7 +1596,7 @@ def _padding(sequences, num_unroll): else: # Only have SparseTensors sparse_lengths = [value.dense_shape[0] for value in sequences_dict.values() if isinstance(value, sparse_tensor.SparseTensor)] - length = math_ops.maximum(sparse_lengths) + length = math_ops.reduce_max(math_ops.to_int32(sparse_lengths)) unroll = array_ops.constant(num_unroll) padded_length = length + ((unroll - (length % unroll)) % unroll) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c1b103c98b7d10abe37461930409ab457ea66c2a..a0c8fae69a02a27b0fce4515c93dfb16bb51b26e 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -651,14 +651,15 @@ cc_library( ":image_ops_op_lib", ":io_ops_op_lib", ":linalg_ops_op_lib", - ":lookup_ops_op_lib", ":logging_ops_op_lib", + ":lookup_ops_op_lib", ":math_ops_op_lib", ":nn_ops_op_lib", ":no_op_op_lib", ":parsing_ops_op_lib", ":random_ops_op_lib", ":remote_fused_graph_ops_op_lib", + ":resource_variable_ops_op_lib", ":script_ops_op_lib", ":sdca_ops_op_lib", ":sendrecv_ops_op_lib", @@ -780,6 +781,7 @@ cc_library( "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", @@ -888,6 +890,7 @@ cc_library( ":test", "//tensorflow/cc:scope", "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:ops_util", ], ) @@ -1772,6 +1775,7 @@ tf_cuda_library( ) + if_mkl( [ "//third_party/mkl:intel_binary_blob", + "@mkl_dnn//:mkl_dnn", ], ), alwayslink = 1, @@ -1932,11 +1936,12 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/visitable_allocator.h", "graph/gradients.h", "graph/quantize_training.h", -] +] + if_mkl(["graph/mkl_graph_util.h"]) tf_cuda_library( name = "core_cpu_impl", srcs = [ + "common_runtime/accumulate_n_optimizer.cc", "common_runtime/allocator_retry.cc", "common_runtime/bfc_allocator.cc", "common_runtime/build_graph_options.cc", @@ -2033,7 +2038,10 @@ tf_cuda_library( "//third_party/eigen3", "//tensorflow/core/kernels:required", ] + if_mkl( - ["//third_party/mkl:intel_binary_blob"], + [ + "//third_party/mkl:intel_binary_blob", + "@mkl_dnn//:mkl_dnn", + ], ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]), alwayslink = 1, ) @@ -2117,6 +2125,7 @@ GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", "common_runtime/gpu/gpu_init.h", + "common_runtime/gpu/gpu_managed_allocator.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", "common_runtime/gpu/pool_allocator.h", @@ -2131,6 +2140,7 @@ tf_cuda_library( "common_runtime/gpu/gpu_debug_allocator.cc", "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", + "common_runtime/gpu/gpu_managed_allocator.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", "common_runtime/gpu/gpu_util_platform_specific.cc", @@ -2662,6 +2672,22 @@ tf_cc_tests( ], ) +tf_cc_test_mkl( + name = "mkl_runtime_tests", + size = "small", + srcs = ["common_runtime/mkl_cpu_allocator_test.cc"], + linkstatic = 1, + deps = [ + ":core", + ":core_cpu", + ":framework", + ":framework_internal", + ":test", + ":test_main", + ":testlib", + ], +) + tf_cc_test_mkl( name = "mkl_related_tests", size = "small", @@ -2669,7 +2695,7 @@ tf_cc_test_mkl( "graph/mkl_layout_pass_test.cc", "graph/mkl_tfconversion_pass_test.cc", ], - linkstatic = tf_kernel_tests_linkstatic(), + linkstatic = 1, deps = [ ":core", ":core_cpu", @@ -2687,6 +2713,9 @@ tf_cc_test_mkl( "//tensorflow/cc:cc_ops", "//tensorflow/cc:scope", "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", + ] + if_mkl([ "//tensorflow/core/kernels:mkl_aggregate_ops", "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", @@ -2699,9 +2728,7 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_tfconv_op", - "//tensorflow/core/kernels:ops_util", - "//third_party/eigen3", - ], + ]), ) tf_cc_tests_gpu( @@ -3323,6 +3350,36 @@ tf_cc_test( ], ) +filegroup( + name = "base_api_def", + data = glob(["api_def/base_api/*"]), +) + +tf_cc_test( + name = "api_test", + srcs = ["api_def/api_test.cc"], + data = [ + ":base_api_def", + "//tensorflow/cc:ops/op_gen_overrides.pbtxt", + ], + tags = [ + "manual", + "notap", + ], + deps = [ + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":lib_test_internal", + ":op_gen_lib", + ":op_gen_overrides_proto_cc", + ":ops", + ":protos_all_cc", + ":test", + ], +) + tf_cc_test_gpu( name = "gpu_tracer_test", size = "small", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ceeb172fa0a9abf2ab7adcfc801b4bcb5fa04381 --- /dev/null +++ b/tensorflow/core/api_def/api_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Test that verifies tensorflow/core/api_def/base_api/api_def*.pbtxt files +// are correct. If api_def*.pbtxt do not match expected contents, run +// tensorflow/core/api_def/base_api/update_api_def.sh script to update them. + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { +constexpr char kDefaultApiDefDir[] = + "tensorflow/core/api_def/base_api"; +constexpr char kOverridesFilePath[] = + "tensorflow/cc/ops/op_gen_overrides.pbtxt"; +constexpr char kApiDefFileFormat[] = "api_def_%c.pbtxt"; +constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + +// Get map from first character to ApiDefs for ops +// that start with that character. +std::unordered_map GenerateApiDef( + const OpList& ops, const OpGenOverrides& overrides) { + std::unordered_map name_to_override; + for (const auto& op_override : overrides.op()) { + name_to_override[op_override.name()] = op_override; + } + + std::unordered_map api_defs_map; + + for (const auto& op : ops.op()) { + CHECK(!op.name().empty()) + << "Encountered empty op name: %s" << op.DebugString(); + const char file_id = toupper(op.name()[0]); + CHECK(isalpha(file_id)) << "Unexpected op name: " << op.name(); + ApiDef* api_def = api_defs_map[file_id].add_op(); + api_def->set_graph_op_name(op.name()); + + if (name_to_override.find(op.name()) != name_to_override.end()) { + const auto& op_override = name_to_override[op.name()]; + // Set visibility + if (op_override.skip()) { + api_def->set_visibility(ApiDef_Visibility_SKIP); + } else if (op_override.hide()) { + api_def->set_visibility(ApiDef_Visibility_HIDDEN); + } + // Add endpoints + if (!op_override.rename_to().empty()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(op_override.rename_to()); + } else { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(op.name()); + } + for (auto& alias : op_override.alias()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(alias); + } + // Add attributes + for (auto& attr : op.attr()) { + auto* api_def_attr = api_def->add_attr(); + api_def_attr->set_name(attr.name()); + for (auto& attr_override : op_override.attr_default()) { + if (attr.name() == attr_override.name()) { + *(api_def_attr->mutable_default_value()) = attr_override.value(); + } + } + for (auto& attr_rename : op_override.attr_rename()) { + if (attr.name() == attr_rename.from()) { + api_def_attr->set_rename_to(attr_rename.to()); + } + } + } + } else { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(op.name()); + } + // Add docs + api_def->set_summary(op.summary()); + api_def->set_description(op.description()); + } + return api_defs_map; +} + +// Reads golden api defs file with the given suffix. +string GetGoldenApiDefsStr(Env* env, const string& api_files_dir, char suffix) { + string file_path = strings::Printf( + io::JoinPath(api_files_dir, kApiDefFileFormat).c_str(), suffix); + if (env->FileExists(file_path).ok()) { + string file_contents; + TF_EXPECT_OK(ReadFileToString(env, file_path, &file_contents)); + return file_contents; + } + return ""; +} + +void RunApiTest(bool update_api_def, const string& api_files_dir) { + // Read C++ overrides file + string overrides_file_contents; + Env* env = Env::Default(); + TF_EXPECT_OK( + ReadFileToString(env, kOverridesFilePath, &overrides_file_contents)); + + // Read all ops + OpList ops; + OpRegistry::Global()->Export(false, &ops); + const std::vector multi_line_fields = {"description"}; + + // Get expected ApiDefs + OpGenOverrides overrides; + auto new_api_defs_map = GenerateApiDef(ops, overrides); + + bool updated_at_least_one_file = false; + + for (char c : kAlphabet) { + string golden_api_defs_str = GetGoldenApiDefsStr(env, api_files_dir, c); + string new_api_defs_str = new_api_defs_map[c].DebugString(); + new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); + if (golden_api_defs_str == new_api_defs_str) { + continue; + } + if (update_api_def) { + string output_file_path = + io::JoinPath(api_files_dir, strings::Printf(kApiDefFileFormat, c)); + if (new_api_defs_str.empty()) { + std::cout << "Deleting " << output_file_path << "..." << std::endl; + TF_EXPECT_OK(env->DeleteFile(output_file_path)); + } else { + std::cout << "Updating " << output_file_path << "..." << std::endl; + TF_EXPECT_OK( + WriteStringToFile(env, output_file_path, new_api_defs_str)); + } + updated_at_least_one_file = true; + } else { + EXPECT_EQ(golden_api_defs_str, new_api_defs_str) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; + } + } + + if (update_api_def && !updated_at_least_one_file) { + std::cout << "Api def files are already up to date." << std::endl; + } +} + +TEST(ApiTest, GenerateBaseAPIDef) { RunApiTest(false, kDefaultApiDefDir); } +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + bool update_api_def = false; + tensorflow::string api_files_dir = tensorflow::kDefaultApiDefDir; + std::vector flag_list = { + tensorflow::Flag( + "update_api_def", &update_api_def, + "Whether to update tensorflow/core/api_def/base_api/api_def*.pbtxt " + "files if they differ from expected API."), + tensorflow::Flag("api_def_dir", &api_files_dir, + "Base directory of api_def*.pbtxt files.")}; + std::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parsed_values_ok) { + std::cerr << usage << std::endl; + return 2; + } + if (update_api_def) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::RunApiTest(update_api_def, api_files_dir); + return 0; + } + testing::InitGoogleTest(&argc, argv); + // Run tests + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/core/api_def/base_api/api_def_A.pbtxt b/tensorflow/core/api_def/base_api/api_def_A.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8193d1bc624535c7894430284686e8664fb71a2d --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_A.pbtxt @@ -0,0 +1,670 @@ +op { + graph_op_name: "Abort" + endpoint { + name: "Abort" + } + summary: "Raise a exception to abort the process when called." + description: <= 2." +} +op { + graph_op_name: "AdjustContrastv2" + endpoint { + name: "AdjustContrastv2" + } + summary: "Adjust the contrast of one or more images." + description: < [2.0132, 1.056] +``` + +@compatibility(numpy) +Equivalent to np.angle. +@end_compatibility +END +} +op { + graph_op_name: "Any" + endpoint { + name: "Any" + } + summary: "Computes the \"logical or\" of elements across dimensions of a tensor." + description: < l1 else 0.0 +accum = accum_new +END +} +op { + graph_op_name: "ApplyFtrlV2" + endpoint { + name: "ApplyFtrlV2" + } + summary: "Update \'*var\' according to the Ftrl-proximal scheme." + description: < l1 else 0.0 +accum = accum_new +END +} +op { + graph_op_name: "ApplyGradientDescent" + endpoint { + name: "ApplyGradientDescent" + } + summary: "Update \'*var\' by subtracting \'alpha\' * \'delta\' from it." +} +op { + graph_op_name: "ApplyMomentum" + endpoint { + name: "ApplyMomentum" + } + summary: "Update \'*var\' according to the momentum scheme. Set use_nesterov = True if you" + description: < threshold`) +or and `false` otherwise. + +This operation is useful for Locality-Sensitive-Hashing (LSH) and other +algorithms that use hashing approximations of cosine and `L2` distances; +codes can be generated from an input via: + +```python +codebook_size = 50 +codebook_bits = codebook_size * 32 +codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], + dtype=x.dtype, + initializer=tf.orthogonal_initializer()) +codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) +codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 +# now codes has shape x.shape[:-1] + [codebook_size] +``` + +**NOTE**: Currently, the innermost dimension of the tensor must be divisible +by 8. + +Given an `input` shaped `[s0, s1, ..., s_n]`, the output is +a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. +END +} +op { + graph_op_name: "Complex" + endpoint { + name: "Complex" + } + summary: "Converts two real numbers to a complex number." + description: < [[2.25 + 4.75j], [3.25 + 5.75j]] +``` +END +} +op { + graph_op_name: "ComplexAbs" + endpoint { + name: "ComplexAbs" + } + summary: "Computes the complex absolute value of a tensor." + description: < [0, 0, 0], [0, 2, 0], [0, 5, 0] +``` + +This is typically used by gradient computations for a concat operation. +END +} +op { + graph_op_name: "ConcatV2" + endpoint { + name: "ConcatV2" + } + summary: "Concatenates tensors along one dimension." +} +op { + graph_op_name: "ConcatenateDataset" + endpoint { + name: "ConcatenateDataset" + } + summary: "Creates a dataset that concatenates `input_dataset` with `another_dataset`." +} +op { + graph_op_name: "ConditionalAccumulator" + endpoint { + name: "ConditionalAccumulator" + } + summary: "A conditional accumulator for aggregating gradients." + description: < [-2.25 - 4.75j, 3.25 - 5.75j] +``` +END +} +op { + graph_op_name: "Const" + endpoint { + name: "Const" + } + summary: "Returns a constant tensor." +} +op { + graph_op_name: "ControlTrigger" + endpoint { + name: "ControlTrigger" + } + summary: "Does nothing. Serves as a control trigger for scheduling." + description: < [a, a * b, a * b * c] +``` + +By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +performed instead: + +```python +tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +``` + +By setting the `reverse` kwarg to `True`, the cumprod is performed in the +opposite direction: + +```python +tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +``` + +This is more efficient than using separate `tf.reverse` ops. + +The `reverse` and `exclusive` kwargs can also be combined: + +```python +tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +``` +END +} +op { + graph_op_name: "Cumsum" + endpoint { + name: "Cumsum" + } + summary: "Compute the cumulative sum of the tensor `x` along `axis`." + description: < [a, a + b, a + b + c] +``` + +By setting the `exclusive` kwarg to `True`, an exclusive cumsum is +performed instead: + +```python +tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] +``` + +By setting the `reverse` kwarg to `True`, the cumsum is performed in the +opposite direction: + +```python +tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] +``` + +This is more efficient than using separate `tf.reverse` ops. + +The `reverse` and `exclusive` kwargs can also be combined: + +```python +tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +``` +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_D.pbtxt b/tensorflow/core/api_def/base_api/api_def_D.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ff8a7223c7223f5c4b72ffc7154b7fc77d8eeb06 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_D.pbtxt @@ -0,0 +1,790 @@ +op { + graph_op_name: "DebugGradientIdentity" + endpoint { + name: "DebugGradientIdentity" + } + summary: "Identity op for gradient debugging." + description: <